
Understanding PyTorch Module

· 4 min read
Vishnu Subramanian

Have you ever wondered 🤔 how PyTorch's nn.Module works? I was always curious about its internals too. Recently I was reading Chapter 19 of fast.ai's Deep Learning for Coders book, where we learn how to build minimal versions of PyTorch and FastAI components like:

  • Dataset, Dataloaders
  • Modules
  • FastAI Learner

This intrigued 🤔 me to take a look at the PyTorch source code for nn.Module. The code for nn.Module is 1000+ lines 😮. After a few cups of coffee ☕☕, I was able to make sense of what is happening inside. Hopefully, by the end of this post, you will understand what goes on inside nn.Module without those cups of coffee 😄.

A simple PyTorch model

All models in PyTorch subclass the nn.Module class, which implements various methods. Some of the methods we commonly use are:

  • to
  • cuda
  • zero_grad
  • train
  • eval
  • load_state_dict
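Some of these methods, such as train and eval, have to reach every layer inside a model and not just the outer module. Below is a minimal pure-Python sketch of that idea. It is not PyTorch's code: MiniModule is a hypothetical class, and its children live in a plain dict. As in nn.Module, train(mode) sets a training flag and recurses into the children, and eval() is simply train(False).

```python
class MiniModule:
    """Hypothetical stand-in for nn.Module, reduced to train/eval."""

    def __init__(self):
        self.training = True  # same default as nn.Module
        self._modules = {}    # name -> child module

    def children(self):
        return iter(self._modules.values())

    def train(self, mode=True):
        # Set the flag on this module, then recurse into every child
        self.training = mode
        for child in self.children():
            child.train(mode)
        return self  # returning self allows chaining, e.g. model.eval().train()

    def eval(self):
        # eval() is just train(False)
        return self.train(False)


parent, child = MiniModule(), MiniModule()
parent._modules["child"] = child  # registered by hand; no __setattr__ magic in this sketch
parent.eval()
print(parent.training, child.training)  # False False
```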

A simple PyTorch model looks like this.

import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Let's look at the __dict__ attribute of our model, which holds the instance attributes of our SimpleModel object.

model = SimpleModel()
model.__dict__

{'training': True,
'_parameters': OrderedDict(),
'_buffers': OrderedDict(),
'_non_persistent_buffers_set': set(),
'_backward_hooks': OrderedDict(),
'_is_full_backward_hook': None,
'_forward_hooks': OrderedDict(),
'_forward_pre_hooks': OrderedDict(),
'_state_dict_hooks': OrderedDict(),
'_load_state_dict_pre_hooks': OrderedDict(),
'_modules': OrderedDict([('conv1',
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))),
('conv2', Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1)))])}

Looking closely at this output, we can see that our layers, like conv1, end up in the _modules dict internally. This left me wondering:

  • How did these modules end up in _modules and not as direct attributes?
  • How are we able to access conv1 as model.conv1?
  • How are we able to call the model as if it were a function, model(input)?
  • How is tab completion able to suggest the keys of _modules when we type model. and press Tab?

Let's put on our detective 🕵️‍♀️ hat and search for answers.

How did the conv1 attribute end up in _modules?

__init__ and __setattr__ are responsible for this behaviour. Let's look at the minimal versions of these functions.

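from collections import OrderedDict
import torch.nn as nn

def __init__(self):
    self._parameters = OrderedDict()
    self._modules = OrderedDict()

def __setattr__(self, name, value):
    # Parameters and sub-modules are routed into their own dicts
    if isinstance(value, nn.Parameter):
        self._parameters[name] = value
    elif isinstance(value, nn.Module):
        self._modules[name] = value
    else:
        # Everything else becomes a regular instance attribute
        object.__setattr__(self, name, value)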

Python calls __setattr__ whenever we assign an attribute, for example self.conv1 = nn.Conv2d(1, 20, 5). Based on the type of the value, PyTorch decides whether to store it in _modules or _parameters, or to assign it to the object itself (the default behavior).
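To see this routing in action without installing PyTorch, here is a self-contained sketch. MiniModule and MiniParameter are hypothetical stand-ins for nn.Module and nn.Parameter. Their __init__ writes the dicts with object.__setattr__ so they are stored as ordinary attributes before any routing happens.

```python
from collections import OrderedDict


class MiniParameter:
    """Hypothetical stand-in for nn.Parameter (just wraps a value)."""

    def __init__(self, data):
        self.data = data


class MiniModule:
    """Hypothetical stand-in for nn.Module, showing only attribute routing."""

    def __init__(self):
        # Bypass our own __setattr__ so these dicts land in __dict__ directly
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, MiniParameter):
            self._parameters[name] = value
        elif isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)


class Model(MiniModule):
    def __init__(self):
        super().__init__()
        self.layer = MiniModule()         # routed to _modules
        self.weight = MiniParameter(0.5)  # routed to _parameters
        self.name = "demo"                # stored as a normal attribute


m = Model()
print(sorted(vars(m)))      # ['_modules', '_parameters', 'name']
print(list(m._modules))     # ['layer']
print(list(m._parameters))  # ['weight']
```

Note that m.layer would raise AttributeError in this sketch. layer lives only in _modules, and nothing yet tells Python to look there. That is the job of __getattr__.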

How are we able to access conv1 as model.conv1?

Python calls __getattr__ only when normal attribute lookup fails, that is, when the name is found neither in the instance's __dict__ nor on its class. So to access conv1, we need to check whether it is present in _modules or _parameters and return the value if it is. The actual code contains a number of validations; a minimal version would look like this.

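def __getattr__(self, name):
    # Only called when normal lookup (instance __dict__, then class) fails
    if name in self._modules:
        return self._modules[name]
    if name in self._parameters:
        return self._parameters[name]
    # Tell Python the attribute really does not exist
    raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")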
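A __getattr__ like this should end by raising AttributeError when nothing matches. If it simply falls off the end, it returns None, and hasattr(model, "anything") would report True. The sketch below uses a hypothetical, PyTorch-free MiniModule that handles only sub-modules. Inside __getattr__ it reads _modules through self.__dict__, which avoids infinite recursion if _modules has not been set yet.

```python
from collections import OrderedDict


class MiniModule:
    """Hypothetical stand-in for nn.Module: registers children and finds them again."""

    def __init__(self):
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            self._modules[name] = value  # children never reach __dict__
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called after normal lookup (instance __dict__, then class) has failed.
        # Reading _modules via __dict__ avoids recursing into __getattr__ if it is missing.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        # Raising AttributeError keeps hasattr() and getattr(obj, name, default) working
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")


model = MiniModule()
model.conv1 = MiniModule()
print("conv1" in model.__dict__)                # False
print(model.conv1 is model._modules["conv1"])  # True
print(hasattr(model, "conv2"))                 # False
```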

How does tab completion work?

I was happy 🤗 with the progress so far, but I got disappointed when I tried building a model with our newly built module. If we type model. and press Tab, sub-modules such as conv1 are not suggested, yet they are with a PyTorch module.

After some digging, I figured out that __dir__ is responsible for this. Let's look at a minimal version.

def __dir__(self):
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    modules = list(self._modules.keys())
    parameters = list(self._parameters.keys())
    keys = module_attrs + attrs + modules + parameters
    return keys
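Python's built-in dir() calls __dir__ and sorts the result. REPL and notebook completers typically build their suggestions from dir(). The hypothetical MiniModule below shows the difference. Its children are stored only in _modules, so the default object.__dir__ misses them, while the overridden __dir__ adds their names.

```python
from collections import OrderedDict


class MiniModule:
    """Hypothetical stand-in for nn.Module: children are hidden in _modules."""

    def __init__(self):
        object.__setattr__(self, "_modules", OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)

    def __dir__(self):
        # Class attributes + instance attributes + registered child names
        return dir(type(self)) + list(self.__dict__) + list(self._modules)


model = MiniModule()
model.conv1 = MiniModule()
print("conv1" in object.__dir__(model))  # False: default lookup only sees __dict__ and the class
print("conv1" in dir(model))             # True: our __dir__ adds the _modules keys
```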

How the forward method works

Ever wondered how PyTorch somehow 🤔 calls the forward method when we pass inputs to the model object, as in model(x)? Python calls __call__ when we use an object as a function.

And yes, our forward method gets called inside __call__. A very minimal version looks like this.

def __call__(self, *input, **kwargs):
    return self.forward(*input, **kwargs)

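In the actual code, PyTorch does more than just call the forward method. For example, it also runs any registered forward pre-hooks before forward and forward hooks after it (the _forward_pre_hooks and _forward_hooks dicts we saw in __dict__).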
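Hooks are one example of that extra work. PyTorch lets you register forward pre-hooks, which receive the inputs before forward, and forward hooks, which receive the inputs and the output afterwards. Either kind may return a replacement value. The sketch below imitates that flow in plain Python. It is a simplification, not PyTorch's implementation: MiniModule is hypothetical, and hooks are kept in plain lists instead of being registered through register_forward_pre_hook and register_forward_hook.

```python
class MiniModule:
    """Hypothetical stand-in for nn.Module, showing only the __call__ path."""

    def __init__(self):
        self._forward_pre_hooks = []  # fn(module, args) -> new args or None
        self._forward_hooks = []      # fn(module, args, output) -> new output or None

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # 1. Pre-hooks may replace the positional inputs
        for hook in self._forward_pre_hooks:
            result = hook(self, args)
            if result is not None:
                args = result if isinstance(result, tuple) else (result,)
        # 2. The actual computation
        output = self.forward(*args, **kwargs)
        # 3. Forward hooks may replace the output
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output


class Double(MiniModule):
    def forward(self, x):
        return 2 * x


m = Double()
print(m(3))  # 6
m._forward_pre_hooks.append(lambda mod, args: (args[0] + 1,))  # add 1 to the input
m._forward_hooks.append(lambda mod, args, out: out * 10)       # scale the output
print(m(3))  # 80, i.e. (3 + 1) * 2 * 10
```

Routing every call through __call__ is what gives a single place to hook into, which is one reason to call model(x) rather than model.forward(x).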

Other interesting methods

If this post has sparked your interest in learning more, try exploring the methods below.

  • get_submodule
  • add_module
  • apply
  • float/double
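As a starting point, here is a hypothetical pure-Python MiniModule that sketches three of these methods:

  • add_module registers a child under a name given as a string.
  • apply calls a function on every child first and then on the module itself. This is the same post-order that nn.Module.apply uses.
  • get_submodule follows a dotted path such as "block.conv1".

PyTorch's versions add validation and error messages that are left out here.

```python
class MiniModule:
    """Hypothetical stand-in for nn.Module with add_module, apply and get_submodule."""

    def __init__(self):
        self._modules = {}

    def add_module(self, name, module):
        # Same effect as `self.<name> = module` in nn.Module, but the name is a string
        self._modules[name] = module

    def children(self):
        return iter(self._modules.values())

    def apply(self, fn):
        # Visit children first, then this module (post-order)
        for child in self.children():
            child.apply(fn)
        fn(self)
        return self

    def get_submodule(self, target):
        # Walk a dotted path such as "block.conv1"; "" means the module itself
        if target == "":
            return self
        module = self
        for part in target.split("."):
            if part not in module._modules:
                raise AttributeError(f"no submodule named {part!r}")
            module = module._modules[part]
        return module


root, block, conv = MiniModule(), MiniModule(), MiniModule()
root.add_module("block", block)
block.add_module("conv1", conv)

print(root.get_submodule("block.conv1") is conv)  # True
visited = []
root.apply(visited.append)
print(visited == [conv, block, root])             # True
```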

I hope you enjoyed reading the blog.

You can find the source code for nn.Module in the PyTorch repository, in torch/nn/modules/module.py.