# Understanding PyTorch Module

Have you ever wondered 🤔 how PyTorch's `nn.Module` works? I was always curious about its internals too. Recently I was reading the 19th chapter of fast.ai's Deep Learning for Coders book, where we learn how to build minimal versions of PyTorch and fastai components like:

• Modules
• FastAI Learner

This made me curious 🤔 enough to look at the PyTorch source code for `nn.Module`. The code for `nn.Module` is 1000+ lines 😮. After a few cups of coffee ☕☕, I was able to make sense of what is happening inside. Hopefully, by the end of this post, you will understand what goes on inside `nn.Module` without those cups of coffee 😄.

## A simple PyTorch model

All PyTorch models subclass the `nn.Module` class, which implements various methods. Some of the methods we commonly use are:

• to
• cuda
• train
• eval

A simple example looks like this.

    import torch.nn as nn
    import torch.nn.functional as F


    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            # Apply each convolution followed by a ReLU.
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Let's look at the `__dict__` attribute of our model, which holds the instance attributes of our `SimpleModel` object.

    model = SimpleModel()
    model.__dict__

    {'training': True,
     '_parameters': OrderedDict(),
     '_buffers': OrderedDict(),
     '_non_persistent_buffers_set': set(),
     '_backward_hooks': OrderedDict(),
     '_is_full_backward_hook': None,
     '_forward_hooks': OrderedDict(),
     '_forward_pre_hooks': OrderedDict(),
     '_state_dict_hooks': OrderedDict(),
     '_load_state_dict_pre_hooks': OrderedDict(),
     '_modules': OrderedDict([('conv1',
                   Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))),
                  ('conv2', Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1)))])}

If we look closely, we can see that our modules/layers like `conv1` end up in the `_modules` dict internally. After looking at the output above, I was wondering:

• How did these modules end up in `_modules` and not as direct attributes?
• How are we able to access `conv1` as `model.conv1`?
• How are we able to pass input to a model as if it were a function, `model(input)`?
• How was tab completion able to suggest the keys of `_modules` when we typed `model.`?

Let's put on our detective 🕵️‍♀️ hat and search for answers.

## How did the conv1 attribute end up in _modules?

`__init__` and `__setattr__` are responsible for this behaviour. Let's look at minimal versions of these methods.

    from collections import OrderedDict

    def __init__(self):
        self._parameters = OrderedDict()
        self._modules = OrderedDict()

    def __setattr__(self, name, value):
        if isinstance(value, nn.Parameter):
            self._parameters[name] = value
        elif isinstance(value, nn.Module):
            # elif, so a Parameter is not also stored as a plain attribute
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

`__setattr__` is called by Python whenever we assign an attribute, for example `self.conv1 = ...`. Based on the type of the value, PyTorch decides whether to store it in `_modules` or `_parameters`, or to set it on the object itself (the default behaviour).
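
To try this routing without installing PyTorch, here is a small self-contained sketch. The `Parameter`, `Module`, `Linear` and `Net` classes are hypothetical stand-ins written for this post, not the real `torch.nn` classes, and they cover only the attribute routing described above.

```python
from collections import OrderedDict


class Parameter:
    """Hypothetical stand-in for nn.Parameter (just wraps a value)."""

    def __init__(self, data):
        self.data = data


class Module:
    """Minimal stand-in for nn.Module, covering only attribute routing."""

    def __init__(self):
        # OrderedDicts are neither Parameter nor Module, so these two
        # assignments fall through to object.__setattr__ and end up in
        # the instance __dict__.
        self._parameters = OrderedDict()
        self._modules = OrderedDict()

    def __setattr__(self, name, value):
        if isinstance(value, Parameter):
            self._parameters[name] = value
        elif isinstance(value, Module):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)


class Linear(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter([1.0, 2.0])  # routed to _parameters
        self.name = "linear"                 # plain attribute


class Net(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear()                   # routed to _modules


net = Net()
print(list(net._modules))                    # ['fc']
print("fc" in net.__dict__)                  # False
print(list(net._modules["fc"]._parameters))  # ['weight']
```

Note that `net.fc` would raise an `AttributeError` at this point, because nothing yet tells Python to look inside `_modules`.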

## How are we able to access conv1 like model.conv1?

`__getattr__` is called by Python only when normal attribute lookup fails, that is, when the name is found neither in the instance `__dict__` nor on the class. So to access `conv1`, we check whether it is present in `_modules` or `_parameters` and return the value if it is. The actual code contains a number of validations; a minimal version would look like this.

    def __getattr__(self, name):
        if name in self._modules.keys():
            return self._modules[name]
        if name in self._parameters.keys():
            return self._parameters[name]
        # Without this, every unknown attribute would silently be None.
        raise AttributeError(name)
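
The lookup can be exercised with a small self-contained sketch. `Module`, `Block` and `Net` here are hypothetical stand-ins rather than the real `torch.nn` classes, and parameters are left out for brevity. The sketch reads the registries through `self.__dict__`, so a missing registry cannot send `__getattr__` into infinite recursion.

```python
from collections import OrderedDict


class Module:
    """Minimal stand-in for nn.Module: attribute routing plus lookup."""

    def __init__(self):
        self._parameters = OrderedDict()
        self._modules = OrderedDict()

    def __setattr__(self, name, value):
        if isinstance(value, Module):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup (instance __dict__, class) fails.
        for registry in ("_modules", "_parameters"):
            store = self.__dict__.get(registry, {})
            if name in store:
                return store[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )


class Block(Module):
    pass


class Net(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Block()


net = Net()
print(net.conv1 is net._modules["conv1"])  # True
print(hasattr(net, "conv2"))               # False
```

Raising `AttributeError` for unknown names matters: `hasattr` and `getattr` with a default both rely on it.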

## How does tab completion work?

I was happy 🤗 with the progress so far, but I was disappointed when I tried building a model with our newly built module. If we type `model.` and press Tab, our submodules are not suggested, whereas it works with the PyTorch module.

After some digging, I figured out that `__dir__` is responsible for it. Let's look at a minimal version.

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        modules = list(self._modules.keys())
        parameters = list(self._parameters.keys())
        keys = module_attrs + attrs + modules + parameters
        return keys
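
As a rough check, the sketch below compares `dir()` on two hypothetical stand-in classes, one without and one with a custom `__dir__`. The names `Module`, `WithDir` and `build` are made up for illustration, and the sketch assumes that tab completion is driven by `dir()`, as it is in the standard Python REPL.

```python
from collections import OrderedDict


class Module:
    """Minimal stand-in for nn.Module without a custom __dir__."""

    def __init__(self):
        self._modules = OrderedDict()

    def __setattr__(self, name, value):
        if isinstance(value, Module):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class WithDir(Module):
    """Same stand-in, but dir() also reports the registered submodules."""

    def __dir__(self):
        keys = dir(self.__class__) + list(self.__dict__) + list(self._modules)
        return sorted(set(keys))


def build(cls):
    # Hypothetical helper: create a model and register one submodule.
    model = cls()
    model.conv1 = Module()
    return model


print("conv1" in dir(build(Module)))   # False
print("conv1" in dir(build(WithDir)))  # True
```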

## How the forward method works

Ever wondered how, when we pass inputs to the model object like `model(x)`, PyTorch somehow 🤔 calls the `forward` method? Python calls `__call__` whenever we use an object as a function.

So our `forward` method gets called inside `__call__`. A very minimal version looks like this.

    def __call__(self, *input, **kwargs):
        return self.forward(*input, **kwargs)

In the actual code, PyTorch does more than just call the `forward` method: for example, it also runs any registered forward pre-hooks and forward hooks around it.
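
To give a feel for that extra work, here is a simplified sketch of a `__call__` that runs hooks before and after `forward`. The `Module` and `Double` classes are hypothetical stand-ins. The hook signatures mirror PyTorch's `register_forward_pre_hook` and `register_forward_hook`, but the real methods return a removable handle and support more options, so treat this as an approximation rather than the actual implementation.

```python
from collections import OrderedDict


class Module:
    """Minimal stand-in for nn.Module with simplified forward hooks."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    def register_forward_pre_hook(self, hook):
        # Simplified: the real method returns a handle for removing the hook.
        self._forward_pre_hooks[len(self._forward_pre_hooks)] = hook

    def register_forward_hook(self, hook):
        self._forward_hooks[len(self._forward_hooks)] = hook

    def __call__(self, *args, **kwargs):
        # Pre-hooks see the positional inputs and may replace them.
        for hook in self._forward_pre_hooks.values():
            result = hook(self, args)
            if result is not None:
                args = result if isinstance(result, tuple) else (result,)
        output = self.forward(*args, **kwargs)
        # Forward hooks see inputs and output and may replace the output.
        for hook in self._forward_hooks.values():
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Double(Module):
    def forward(self, x):
        return 2 * x


calls = []
double = Double()
double.register_forward_pre_hook(lambda module, args: calls.append(("pre", args)))
double.register_forward_hook(lambda module, args, out: calls.append(("post", out)))

result = double(21)
print(result)  # 42
print(calls)   # [('pre', (21,)), ('post', 42)]
```

This is also why we call `model(x)` rather than `model.forward(x)`: calling `forward` directly would skip the hooks.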

## Other interesting methods

If reading this post has sparked your interest, you can try understanding the functions below.

• get_submodule
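
As a starting point, here is a rough sketch of the idea behind `get_submodule`: it takes a dotted path such as `"block.conv"` and walks down through the nested modules. The `Module` class below is a hypothetical stand-in that looks names up directly in `_modules`; the real method does additional validation.

```python
from collections import OrderedDict


class Module:
    """Minimal stand-in for nn.Module with a simplified get_submodule."""

    def __init__(self):
        self._modules = OrderedDict()

    def add_module(self, name, module):
        self._modules[name] = module

    def get_submodule(self, target):
        # An empty path refers to the module itself.
        if target == "":
            return self
        current = self
        for part in target.split("."):
            if part not in current._modules:
                raise AttributeError(
                    f"no submodule named '{part}' in path '{target}'"
                )
            current = current._modules[part]
        return current


root, block, conv = Module(), Module(), Module()
block.add_module("conv", conv)
root.add_module("block", block)

print(root.get_submodule("block.conv") is conv)  # True
print(root.get_submodule("") is root)            # True
```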

You can find the source code for `nn.Module` in `torch/nn/modules/module.py` in the PyTorch repository.