Optimizing some deep learning code may seem quite complicated. After all, PyTorch is already super optimized, so why (and how) could one improve what is already great?
For the why, there are many reasons:
Why?
GPU time is expensive
As data scientists, we are used to telling our bosses or investors that we need more money to improve our models (oh, look at that beautiful GPU). What if you could halve your training time for free?
Less training time means more accuracy
The money argument aside, you can test more models in the same amount of time! In the end, it means that you are more likely to end up with a better model.
Optimizing code is super satisfying
Honestly, there is no other way to put it. Instead of optimizing the accuracy of your model, try to optimize its training time. See how good this feels!
Save the planet
Well, it is actually a pity to think that in the example below (taken from a typical PyTorch training loop), half of the time is simply wasted. We will not degrade the (validation) loss of the model, yet we will spend less electricity and time to achieve it!
How?
Well, we will not touch anything that is inside PyTorch, obviously. I just noticed that, as data scientists, we may not be too aware of the low-hanging fruit in our scripts when it comes to performance. Here, the tricks will mostly lie in data loading and transformations.
A concrete case
Without further ado, let's start! I will focus on the most common parts of a PyTorch training script, in the case of image recognition. Here, the problem is transfer learning from a pretrained model (resnet34, because it is fast to execute) to a binary classification problem.
The images are satellite data so they have more channels than a usual RGB image.
We will focus on the image loading function:
def load_and_convert_tiff(file_path):
    # Read the multispectral image and rescale three of its bands into a uint8 RGB image
    image = tiff.imread(file_path)
    R = image[:, :, 1] * 255 * 2
    G = image[:, :, 2] * 255 * 2
    B = image[:, :, 3] * 255 * 2
    rgb_image = np.stack((R, G, B), axis=2).astype(np.uint8)
    rgb_image = Image.fromarray(rgb_image)
    return rgb_image
And the Dataset class:
class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = load_and_convert_tiff(img_path)
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
Along with the composition of transformations:
augment_and_transform = transforms.Compose(
    [transforms.Resize(334),
     transforms.ToTensor(),
     transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225])])
We have a usual training loop, similar to the one PyTorch suggests.
In order to keep the article simple, I will not dive into other parts of the code until it becomes necessary.
Let’s just assume that the training loop is properly implemented, and we use tqdm to measure the training time of our model.
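To give a rough idea, a minimal sketch of what that outer loop could look like is shown below; train_one_epoch is the function detailed later in this article, while validation_loader, num_epochs, model, device and loss_fn are assumed names from my script, not something you need to copy verbatim:

for epoch in range(num_epochs):
    model.train()
    train_loss = train_one_epoch(epoch)

    # Validation pass: no gradients, model in eval mode
    model.eval()
    valid_loss = 0.
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(validation_loader):
            inputs = inputs.to(torch.device(device))
            labels = labels.to(torch.device(device))
            valid_loss += loss_fn(model(inputs), labels).item()
    valid_loss /= (i + 1)

    print(f"EPOCH: {epoch} | LOSS train: {train_loss:.3f} | LOSS valid: {valid_loss:.3f}")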
These results will be our baseline for what follows:
32it [00:11, 2.76it/s]
EPOCH: 0 | LOSS train: 0.511 | LOSS valid: 0.495
32it [00:11, 2.82it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.405
32it [00:10, 3.00it/s]
EPOCH: 2 | LOSS train: 0.352 | LOSS valid: 0.359
Here, 2.76it/s means that 2.76 batches are processed per second (the dataset is split into 32 batches, so one epoch takes 32 / 2.76 ≈ 11.6 seconds, which matches the [00:11] displayed by tqdm). This is the number we will focus on.
We will also keep an eye on the training and validation losses to make sure we do not break things.
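Before touching anything, it can also be worth timing the data pipeline on its own. A minimal sketch (assuming training_loader is the DataLoader used in the training loop) is simply to iterate over it without calling the model; if this loop alone is slow, the bottleneck is data loading rather than the GPU:

# Time one full pass over the DataLoader, without the model
for inputs, labels in tqdm(training_loader):
    pass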
Get rid of the useless
The line rgb_image = Image.fromarray(rgb_image) is actually useless. Some people keep it because transforms.Resize() may behave slightly differently on PIL images than on tensors (depending on the parameters you feed to the transform). We can simply remove it from the function:
def load_and_convert_tiff(file_path):
    image = tiff.imread(file_path)
    R = image[:, :, 1] * 255 * 2
    G = image[:, :, 2] * 255 * 2
    B = image[:, :, 3] * 255 * 2
    rgb_image = np.stack((R, G, B), axis=2).astype(np.uint8)
    return rgb_image
And now we swap transforms.ToTensor() and transforms.Resize(...), because Resize does not accept NumPy arrays: the array has to be converted to a tensor first.
augment_and_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT),
     transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225])])
We can rerun our script and…
32it [00:09, 3.45it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.506
32it [00:09, 3.43it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.407
32it [00:09, 3.38it/s]
EPOCH: 2 | LOSS train: 0.351 | LOSS valid: 0.357
Tada! The speedup is already noticeable: from roughly 2.8 to about 3.4 batches per second.
Use your RAM if possible!
Note that this advice will not work if your dataset is too large!
Here, we have roughly 1,000 training images. This is quite low, and my machine has 32 GB of RAM, so I might as well load them once and for all. Besides, it will spare my SSD.
Instead of reading the image from the hard drive each time __getitem__ is called, we can build a list of images that is kept in memory. In my case, these images only represent about 10% of my RAM.
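If you want to check beforehand that the cached images will fit, a quick back-of-the-envelope estimate is enough. This is only a sketch: file_paths is assumed to be the list of training image paths, and all images are assumed to have the same shape.

# Rough estimate of the RAM needed to cache all decoded images
sample = load_and_convert_tiff(file_paths[0])
estimated_gb = np.asarray(sample).nbytes * len(file_paths) / 1e9
print(f"Caching {len(file_paths)} images would take ~{estimated_gb:.1f} GB of RAM")

With that confirmed, the Dataset class can cache the images at construction time: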
class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.images = []
        for file_path in tqdm(self.file_paths):
            image = load_and_convert_tiff(file_path)
            self.images.append(image)
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
Note the overhead! When creating the dataset, it now takes about 2 seconds to load all the images into memory.
100%|███████████████████████████████| 993/993 [00:02<00:00, 418.36it/s]
100%|███████████████████████████████| 249/249 [00:00<00:00, 380.70it/s]
But wow, the speedup in training is totally worth it! We are close to halving our initial training time.
32it [00:06, 5.01it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.506
32it [00:06, 5.31it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.407
32it [00:06, 5.30it/s]
EPOCH: 2 | LOSS train: 0.351 | LOSS valid: 0.357
The resize still happens every time an image is requested, so each pass over the training set resizes the same images again and again. Let's get rid of it and turn the augmentation transform:
augment_and_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT),
     transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225])])
To:
augment_and_transform = transforms.Compose(
    [transforms.RandomRotation(degrees=90),
     transforms.RandomVerticalFlip(p=0.5),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225])])
So that only the data augmentation happens at training time. Now, when loading the images into memory, let's apply the shared transformations once:
class Sentinel2Dataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        factored_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Resize(334, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT)])
        self.file_paths = file_paths
        self.images = []
        for file_path in tqdm(self.file_paths):
            image = load_and_convert_tiff(file_path)
            transformed_image = factored_transform(image)
            self.images.append(transformed_image)
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
And this is it, the training time decreased once more:
32it [00:06, 5.12it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.506
32it [00:05, 5.45it/s]
EPOCH: 1 | LOSS train: 0.389 | LOSS valid: 0.407
32it [00:05, 5.46it/s]
EPOCH: 2 | LOSS train: 0.351 | LOSS valid: 0.357
Play with deterministic / benchmark
Some of you may be familiar with the cudnn deterministic and benchmark flags. I usually see them set in this kind of utility function:
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
If you invert the booleans, your results may not be exactly the same at every run (the difference should be small though), but you will gain some extra time.
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
As the running output shows:
32it [00:07, 4.39it/s]
EPOCH: 0 | LOSS train: 0.519 | LOSS valid: 0.487
32it [00:05, 5.61it/s]
EPOCH: 1 | LOSS train: 0.392 | LOSS valid: 0.410
32it [00:05, 5.63it/s]
EPOCH: 2 | LOSS train: 0.358 | LOSS valid: 0.400
Note that the variation in the loss is a bit worrisome…
Use (try) gradient accumulation
Gradient accumulation seemed promising to me. I read about it on this blog and the heuristic seemed interesting. Besides, as it reduces the number of gradient updates to the model, I expected to gain some performance (and this should be particularly true on larger models).
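One detail worth keeping in mind: accumulating gradients over accum_iter batches means the optimizer takes accum_iter times fewer steps per epoch, each computed on an effective batch size of accum_iter × batch_size. Hyperparameters such as the learning rate are often retuned when the effective batch size changes that much.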
The recipe consists in turning the training loop:
def train_one_epoch(epoch_index):
    total_loss = 0.
    for i, data in tqdm(enumerate(training_loader)):
        inputs, labels = data
        inputs = inputs.to(torch.device(device))
        labels = labels.to(torch.device(device))
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, labels)
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
    return total_loss / (i+1)
Into this one, where the gradient update is performed every accum_iter steps:
def train_one_epoch(epoch_index):
    total_loss = 0.
    accum_iter = 4
    optimizer.zero_grad()
    for i, data in tqdm(enumerate(training_loader)):
        inputs, labels = data
        inputs = inputs.to(torch.device(device))
        labels = labels.to(torch.device(device))
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, labels)
        batch_loss = batch_loss / accum_iter
        batch_loss.backward()
        total_loss += batch_loss.item()
        if ((i + 1) % accum_iter == 0) or (i + 1 == len(training_loader)):
            optimizer.step()
            optimizer.zero_grad()
    return total_loss / (i+1)
But the drop in validation performance is too large (maybe I am doing something wrong?) for no gain in execution time. Also note that the reported train loss is not directly comparable to the previous runs: each batch_loss is divided by accum_iter before being summed, so the average printed below is roughly 4 times smaller than before.
32it [00:06, 5.05it/s]
EPOCH: 0 | LOSS train: 0.156 | LOSS valid: 0.612
32it [00:06, 5.33it/s]
EPOCH: 1 | LOSS train: 0.127 | LOSS valid: 0.532
32it [00:06, 5.30it/s]
EPOCH: 2 | LOSS train: 0.110 | LOSS valid: 0.482
Reduce the image size
By reducing the size of the images, we can achieve a massive speedup. Just note the 228 instead of 334:
factored_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(228, antialias=False, interpolation=InterpolationMode.NEAREST_EXACT)])
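To get an intuition for the magnitude of the gain: every spatial dimension shrinks by a factor of 334 / 228 ≈ 1.46, so each image contains roughly (334 / 228)² ≈ 2.1 times fewer pixels to augment, normalize and push through the network.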
And we are 4 times faster than our initial benchmark!
32it [00:04, 7.78it/s]
EPOCH: 0 | LOSS train: 0.518 | LOSS valid: 0.500
32it [00:02, 10.76it/s]
EPOCH: 1 | LOSS train: 0.393 | LOSS valid: 0.436
32it [00:02, 10.76it/s]
EPOCH: 2 | LOSS train: 0.363 | LOSS valid: 0.393
However, this kind of optimization changes what we are actually computing and seems to harm the loss of the model…
Conclusion
This is it! We almost halved the training time of our model without harming its performance :) And if we allow ourselves to trade off some model performance, we saw that smaller images are a way to go much faster.
I do not have other tricks that can easily be applied at the moment. I hope you liked this article; do not hesitate to share it with your friends and colleagues, and on social media!
Learning more
If you are new to machine learning, Deep Learning by Ian Goodfellow, Yoshua Bengio and Aaron Courville is an excellent introduction to the topic. The algorithms and mathematics are presented without any code, so it will not become outdated as soon as a new breaking change is introduced in the main packages ;) *Note that this is a sponsored link.