
PyTorch Lazy modules

Vishnu Subramanian

When designing deep learning modules such as a classification head, we have to work out the number of input features ourselves. PyTorch Lazy modules come to the rescue by inferring it for us.

In this post, we will explore how we can use PyTorch Lazy modules to re-write PyTorch models used for

  • Image classifiers
  • Unet

When solving deep learning problems, it is common to try out different architectures like efficientnet and resnet.

You can find the code for the blog here

Lazy module with Image classification

Let’s look at some example code that modifies the head of two architectures.

import timm
import torch
import torch.nn as nn

resnet = timm.create_model(model_name='resnet50', pretrained=True)
effnet = timm.create_model(model_name='tf_efficientnetv2_b0', pretrained=True)

resnet.fc, effnet.classifier

Output

(Linear(in_features=2048, out_features=1000, bias=True),
Linear(in_features=1280, out_features=1000, bias=True))

Let’s change the architecture of the model head.

def create_head_old(in_features, out_features):
    head = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(),
        nn.Linear(in_features=512, out_features=out_features, bias=False))
    return head

resnet.fc = create_head_old(2048, 1)
effnet.classifier = create_head_old(1280, 1)

The code above replaces the Linear layer in each model with a slightly more complex head. Notice that we have to hardcode the input features for each model, which makes it difficult to try different model families. The Lazy module feature provides lazy variants of Linear, BatchNorm, Conv and ConvTranspose that infer their input size automatically. This is easiest to understand through an example. Let’s replace the first Linear layer in create_head_old with its lazy variant.

def create_head_new(out_features):
    head = nn.Sequential(
        nn.LazyLinear(512, bias=False),  # in_features is inferred on the first forward pass
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(),
        nn.Linear(in_features=512,
                  out_features=out_features,
                  bias=False))
    return head

resnet.fc = create_head_new(1)
effnet.classifier = create_head_new(1)

resnet.fc

Output:

Sequential(
  (0): LazyLinear(in_features=0, out_features=512, bias=False)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=1, bias=False)
)

We replace only the first Linear layer of our custom head. In the output, in_features is shown as 0; the real value is inferred from the input on the first forward pass. We can simply pass a dummy batch through the model, after which the layer is initialised and behaves like a normal Linear layer.

dummy_tensor = torch.randn((2,3,224,224))
_ = resnet(dummy_tensor)
resnet.fc

Output:

Sequential(
  (0): Linear(in_features=2048, out_features=512, bias=False)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=1, bias=False)
)
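
To see the same behaviour without downloading pretrained weights, here is a minimal sketch that applies a lazy head to random feature tensors. The widths 2048 and 1280 stand in for the ResNet-50 and EfficientNetV2-B0 feature sizes shown earlier, and the `make_head` name is hypothetical (it simply mirrors create_head_new).

```python
import torch
import torch.nn as nn


def make_head(out_features):
    # Hypothetical helper mirroring create_head_new: only the first layer is lazy.
    return nn.Sequential(
        nn.LazyLinear(512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(),
        nn.Linear(512, out_features, bias=False),
    )


head_a = make_head(1)
head_b = make_head(1)

# Before the first forward pass the lazy layer's weight has no shape yet.
pending = head_a[0].has_uninitialized_params()

# Random features standing in for backbone outputs of different widths.
out_a = head_a(torch.randn(2, 2048))
out_b = head_b(torch.randn(2, 1280))

print(pending)                     # was the lazy layer still uninitialised?
print(type(head_a[0]).__name__)    # class of the first layer after the dry run
print(head_a[0].in_features, head_b[0].in_features)
```

The same factory call produces two heads whose first layers end up with different in_features, which is the property the post relies on when swapping backbones.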

Now we can use our new head on any model without hardcoding input features. The advantage of Lazy modules may not be obvious in such a simple example, but they make life a lot easier when building architectures for segmentation and object detection. Let’s see how Lazy modules can simplify a Unet architecture and make it flexible enough to accept different models as encoders.

Lazy module with Unet Architecture:

For this example, I use a trimmed version of the Unet shown in one of my earlier videos, where I explained how to build a Unet model using Timm and Fastai. Looking through the complete code could be overwhelming, so I recommend focusing on the Unet and UnetDecoder classes.

import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, model_name='resnext50_32x4d'):
        super().__init__()
        # features_only=True makes timm return a list of intermediate feature maps
        self.encoder = timm.create_model(model_name, features_only=True, pretrained=False)

    def forward(self, x):
        return self.encoder(x)

def conv_block(in_feat, out_feat):
    conv_block = nn.Sequential(nn.Conv2d(in_feat, out_feat, 3, 1, 1, bias=False),
                               nn.BatchNorm2d(out_feat),
                               nn.ReLU())
    return conv_block

class UnetBlock(nn.Module):
    def __init__(self, in_channels, channels, out_channels):
        super().__init__()
        self.conv1 = conv_block(in_channels, channels)
        self.conv2 = conv_block(channels, out_channels)

    def forward(self, x):
        # Upsample by 2x, then refine with two conv blocks
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class UnetDecoder(nn.Module):
    def __init__(self, fs=32, expansion=4, n_out=1):
        super().__init__()
        center_ch = 512*expansion
        decoder5_ch = center_ch + (256*expansion)
        channels = 512
        self.center = nn.Sequential(conv_block(center_ch, center_ch), conv_block(center_ch, center_ch//2))
        self.decoder5 = UnetBlock(decoder5_ch, channels, fs)
        self.decoder4 = UnetBlock(256*expansion+fs, 256, fs)
        self.decoder3 = UnetBlock(128*expansion+fs, 128, fs)
        self.decoder2 = UnetBlock(64*expansion+fs, 64, fs)
        self.decoder1 = UnetBlock(fs, fs, fs)
        self.logit = nn.Sequential(conv_block(fs, fs//2), conv_block(fs//2, fs//2), nn.Conv2d(fs//2, n_out, kernel_size=1))

    def forward(self, feats):
        e1, e2, e3, e4, e5 = feats  # channels for resnext50_32x4d: 64, 256, 512, 1024, 2048
        f = self.center(e5)
        d5 = self.decoder5(torch.cat([f, e5], 1))
        d4 = self.decoder4(torch.cat([d5, e4], 1))
        d3 = self.decoder3(torch.cat([d4, e3], 1))
        d2 = self.decoder2(torch.cat([d3, e2], 1))
        d1 = self.decoder1(d2)
        return self.logit(d1)

class Unet(nn.Module):
    def __init__(self, fs=32, expansion=4, model_name='resnext50_32x4d', n_out=1):
        super().__init__()
        self.encoder = Encoder(model_name)
        self.decoder = UnetDecoder(fs=fs, expansion=expansion, n_out=n_out)

    def forward(self, x):
        feats = self.encoder(x)
        out = self.decoder(feats)
        return out

Let’s focus on the UnetBlock instances in the UnetDecoder class and observe how we have to hardcode their input features. We have to calculate these values manually, and they can change when we try a different architecture, which breaks our code. Anyone else reading the code is also left wondering how we came up with them.
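
The manual calculation can be written out as plain arithmetic. The sketch below compares the input channels that UnetDecoder hardcodes with the channels each layer would actually receive for a given encoder. The function names are hypothetical, the resnext50 widths come from the comment in forward(), and the resnet18 widths are an assumption about what its feature extractor returns.

```python
def hardcoded_in_channels(fs=32, expansion=4):
    # Input channels UnetDecoder hardcodes for center, decoder5, ..., decoder2.
    center_ch = 512 * expansion
    return [
        center_ch,
        center_ch + 256 * expansion,
        256 * expansion + fs,
        128 * expansion + fs,
        64 * expansion + fs,
    ]


def received_in_channels(encoder_channels, fs=32, expansion=4):
    # Channels those same layers actually receive for a given encoder.
    _, e2, e3, e4, e5 = encoder_channels
    center_out = (512 * expansion) // 2  # output width of self.center
    return [e5, center_out + e5, fs + e4, fs + e3, fs + e2]


resnext50 = [64, 256, 512, 1024, 2048]  # from the comment in forward()
resnet18 = [64, 64, 128, 256, 512]      # assumed feature widths for resnet18

expected = hardcoded_in_channels()
print(expected)                                     # [2048, 3072, 1056, 544, 288]
print(received_in_channels(resnext50) == expected)  # True
print(received_in_channels(resnet18) == expected)   # False
```

With the default expansion=4 the numbers only line up for encoders whose stages have the ResNeXt-50 widths, which is why swapping the backbone breaks the hardcoded decoder.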

Let’s simplify the above code by adopting the lazy modules. Most of the code remains unchanged; only conv_block, UnetBlock and UnetDecoder change (the decoder below calls a LazyUnetBlock that no longer takes in_channels). By simply replacing Conv2d with LazyConv2d, we no longer need any hardcoding or manual calculation.

def lazy_conv_block(out_feat):
    conv_block = nn.Sequential(nn.LazyConv2d(out_feat, 3, 1, 1, bias=False),
                               nn.BatchNorm2d(out_feat),
                               nn.ReLU())
    return conv_block

class LazyUnetDecoder(nn.Module):
    def __init__(self, fs=32, expansion=4, n_out=1):
        super().__init__()
        channels = 512
        center_ch = channels*expansion
        self.center = nn.Sequential(lazy_conv_block(center_ch), lazy_conv_block(center_ch//2))
        self.decoder5 = LazyUnetBlock(channels, fs)
        self.decoder4 = LazyUnetBlock(channels//2, fs)
        self.decoder3 = LazyUnetBlock(channels//4, fs)
        self.decoder2 = LazyUnetBlock(channels//8, fs)
        self.decoder1 = LazyUnetBlock(fs, fs)
        self.logit = nn.Sequential(lazy_conv_block(fs//2), lazy_conv_block(fs//2), nn.Conv2d(fs//2, n_out, kernel_size=1))

    def forward(self, feats):
        e1, e2, e3, e4, e5 = feats  # channels for resnext50_32x4d: 64, 256, 512, 1024, 2048
        f = self.center(e5)
        d5 = self.decoder5(torch.cat([f, e5], 1))
        d4 = self.decoder4(torch.cat([d5, e4], 1))
        d3 = self.decoder3(torch.cat([d4, e3], 1))
        d2 = self.decoder2(torch.cat([d3, e2], 1))
        d1 = self.decoder1(d2)
        return self.logit(d1)
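
The decoder above relies on a LazyUnetBlock that is not shown in this post. A plausible version, assuming it is just UnetBlock with the in_channels argument dropped and conv_block swapped for lazy_conv_block, might look like the sketch below. This is an assumption by analogy with UnetBlock, not necessarily the exact code from the linked repository. The two random inputs mimic what decoder4 would receive from encoders with 1024 and 256 channels at that stage.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def lazy_conv_block(out_feat):
    # Same as the post's lazy_conv_block: input channels are inferred on first call.
    return nn.Sequential(
        nn.LazyConv2d(out_feat, 3, 1, 1, bias=False),
        nn.BatchNorm2d(out_feat),
        nn.ReLU(),
    )


class LazyUnetBlock(nn.Module):
    # Assumed definition: UnetBlock without the in_channels argument.
    def __init__(self, channels, out_channels):
        super().__init__()
        self.conv1 = lazy_conv_block(channels)
        self.conv2 = lazy_conv_block(out_channels)

    def forward(self, x):
        # Upsample by 2x, then refine with two lazy conv blocks.
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv2(self.conv1(x))


# Identical constructor calls, inputs with different channel counts.
block_a = LazyUnetBlock(256, 32)
block_b = LazyUnetBlock(256, 32)
out_a = block_a(torch.randn(2, 1024 + 32, 14, 14))
out_b = block_b(torch.randn(2, 256 + 32, 14, 14))

print(tuple(out_a.shape), block_a.conv1[0].in_channels)
print(tuple(out_b.shape), block_b.conv1[0].in_channels)
```

LazyUnet would then presumably be the Unet class from earlier with UnetDecoder replaced by LazyUnetDecoder; nothing else in it depends on channel counts.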

In addition to simplifying the code, the lazy version works with most backbones. Let’s try a different backbone with both versions and see what happens.

Fails ❌

Unet(model_name='resnet18')(dummy_tensor).shape

Works ✅

LazyUnet(model_name='resnet18')(dummy_tensor).shape
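
The difference between the two calls comes down to a single layer type. As a minimal sketch, assume the deepest feature map of a small encoder has 512 channels (an illustrative shape, not taken from the post), and feed it to a Conv2d sized for 2048 channels and to a LazyConv2d:

```python
import torch
import torch.nn as nn

# Feature map with 512 channels (assumed shape for a small encoder's last stage).
feat = torch.randn(1, 512, 7, 7)

hardcoded = nn.Conv2d(2048, 64, 3, 1, 1, bias=False)  # sized for a 2048-channel encoder
lazy = nn.LazyConv2d(64, 3, 1, 1, bias=False)         # input channels inferred on first call

try:
    hardcoded(feat)
    hardcoded_ok = True
except RuntimeError as err:
    # PyTorch reports the channel mismatch as a RuntimeError.
    hardcoded_ok = False
    print('hardcoded conv failed:', str(err).splitlines()[0])

out = lazy(feat)
print(hardcoded_ok, tuple(out.shape), lazy.in_channels)
```

The hardcoded layer rejects the input, while the lazy layer sizes its weight from the first batch it sees.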

Conclusion:

PyTorch’s Lazy modules are going to simplify writing a lot of future architectures. Since this is a new feature under heavy development, the API or functionality could change in the future.