Have you ever wondered 🤔 how PyTorch's `nn.Module` works internally? I was always curious to understand how the internals work too. Recently I was reading chapter 19 of Fast.ai's Deep Learning for Coders book, where we learn how to build minimal versions of FastAI modules like
- Dataset, Dataloaders
- FastAI Learner

This intrigued 🤔 me to take a look at the PyTorch source code for `nn.Module`. The code for `nn.Module` is 1000+ lines 😮. After a few cups of coffee ☕☕, I was able to make sense of what is happening inside. Hopefully, by the end of this post, you will have an understanding of what goes inside `nn.Module` without those cups of coffee 😄.
A simple PyTorch model
All the models in PyTorch subclass the `nn.Module` class, which implements various methods. Some of the methods we usually use are `forward()`, `parameters()`, and `state_dict()`.
A simple example looks like this.
```python
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```
Let's look at the `__dict__` attribute of our model, which contains the attributes of the model object.

```python
model = SimpleModel()
model.__dict__
```

The output (trimmed to the part we care about) looks like this:

```
{...,
 '_modules': OrderedDict([('conv1',
               Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))),
              ('conv2', Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1)))])}
```
If we observe the output closely, we can see that our modules/layers like `conv1` end up in the `_modules` dict internally, and not as direct attributes. This made me wonder about a few things -

- How did these modules end up in `_modules` and not as direct attributes?
- How are we able to access them directly, like `model.conv1`?
- How are we able to pass input to the model as if it were a function, like `model(x)`?
- How was tab completion able to recommend the keys of `_modules` when we typed `model.` + tab?

Let's put on our detective 🕵️‍♀️ hat and search for answers.
How did the `conv1` attribute end up in `_modules`?

The `__init__` and `__setattr__` methods are responsible for this behaviour. Let's look at the minimal versions of these functions.
```python
from collections import OrderedDict

def __init__(self):
    # every module keeps its parameters and submodules
    # in these internal OrderedDicts
    self._parameters = OrderedDict()
    self._modules = OrderedDict()

def __setattr__(self, name, value):
    # route the value to the right dict based on its type,
    # everything else becomes a normal attribute
    if isinstance(value, nn.Parameter):
        self._parameters[name] = value
    elif isinstance(value, nn.Module):
        self._modules[name] = value
    else:
        object.__setattr__(self, name, value)
```
`__setattr__` is called by Python whenever we assign an attribute, for example `self.conv1 = nn.Conv2d(1, 20, 5)`. Based on the type of the value, PyTorch decides whether to store it in `_modules` or `_parameters`, or to assign it to the object itself (which is the default behaviour).
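We can see this routing in action with the `SimpleModel` we built earlier - a quick sanity check against the real `nn.Module`:

```python
import torch

model = SimpleModel()

# conv1 was intercepted by __setattr__, so it is not a plain instance attribute ...
print('conv1' in model.__dict__)     # False
# ... it lives in the internal _modules dict instead
print('conv1' in model._modules)     # True

# an nn.Parameter assigned the same way ends up in _parameters
model.scale = nn.Parameter(torch.ones(1))
print('scale' in model._parameters)  # True
```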
How are we able to access `conv1` like `model.conv1`?

`__getattr__` is called by Python when it cannot find an attribute directly in the `__dict__`. So to access `conv1` we need to check if it is present in either `_modules` or `_parameters` and return the value if present. The actual code contains a number of validations; a minimal version would look like this.
```python
def __getattr__(self, name):
    # only called when normal attribute lookup (self.__dict__) fails
    if name in self._modules:
        return self._modules[name]
    if name in self._parameters:
        return self._parameters[name]
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
```
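Continuing with our `SimpleModel` (and the real `nn.Module` behaviour), we can watch this fallback in action:

```python
model = SimpleModel()

# normal lookup fails, so Python falls back to nn.Module.__getattr__,
# which returns the entry stored in _modules
print(model.conv1 is model._modules['conv1'])   # True

# a name that exists in neither dict raises AttributeError as usual
try:
    model.conv3
except AttributeError as err:
    print(err)
```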
How tab completion works

I was happy 🤗 with the progress till here, but I got disappointed when I tried building a model with our newly built module. If we try tab completion on it with `model.` + tab, nothing happens, but it works fine with a model built on PyTorch's `nn.Module`. After some digging, I figured out that the `__dir__` method is responsible for it. Let's look at the minimal version.
```python
def __dir__(self):
    # merge class attributes, instance attributes,
    # and the keys of _modules and _parameters
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    modules = list(self._modules.keys())
    parameters = list(self._parameters.keys())
    keys = module_attrs + attrs + modules + parameters
    return sorted(keys)
```
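Since the real `nn.Module` defines `__dir__` this way, `dir(model)` (which is what IPython/Jupyter consult for tab completion) lists the submodule names too:

```python
model = SimpleModel()

# dir() calls __dir__ under the hood, so the merged keys show up here -
# these are exactly the names tab completion suggests after `model.`
print('conv1' in dir(model))    # True
print('forward' in dir(model))  # True
```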
How the forward method works
Ever wondered how, when we pass inputs to the model object like `model(x)`, PyTorch somehow 🤔 ends up running our `forward` method? The answer is `__call__`, which Python invokes when we use the object as a function. Yes, our `forward` method gets called inside `__call__`. A very minimal version will look like this.
```python
def __call__(self, *input, **kwargs):
    return self.forward(*input, **kwargs)
```
In the actual code, PyTorch does more than just call the `forward` method - it also takes care of things like running the forward and backward hooks registered on the module.
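We can see both halves of that with our `SimpleModel` - the call to `forward` and the extra hook machinery (the input shape below is just an arbitrary choice for the example):

```python
import torch

model = SimpleModel()
x = torch.randn(1, 1, 28, 28)   # dummy batch: one 1-channel 28x28 image

# model(x) goes through nn.Module.__call__, which calls model.forward(x)
out = model(x)
print(out.shape)                # torch.Size([1, 20, 20, 20])

# __call__ is also where registered hooks are run
model.register_forward_hook(lambda mod, inp, output: print('hook saw', output.shape))
model(x)                        # prints: hook saw torch.Size([1, 20, 20, 20])
```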
Other interesting methods
If reading this blog has sparked your interest to understand more, you can try reading the source of some of the other interesting `nn.Module` methods, like `parameters()`, `apply()`, `train()`/`eval()`, and `state_dict()`.
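As a starting point, here is a rough sketch (not PyTorch's actual implementation, which goes through `named_parameters()` and de-duplicates shared parameters) of how something like `parameters()` can be built on top of the same two dicts:

```python
def parameters(self):
    # this module's own parameters ...
    for param in self._parameters.values():
        yield param
    # ... followed by the parameters of every child module, recursively
    for module in self._modules.values():
        yield from module.parameters()
```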
I hope you enjoyed reading the blog.
You can find the source code for `nn.Module` in `torch/nn/modules/module.py` in the PyTorch GitHub repository.