模型创建与nn.Module
1. 网络模型创建步骤
前几次已经学完了数据部分,现在开始学习模型部分。
模型又分为模型创建和权值初始化。其中,模型创建还可以进一步分为构建网络层和拼接网络层。这些工作都可以用nn.Module来完成。
本节我们用7层结构的LeNet来学习模型创建和nn.Module。
实验:
1 | import os |
-
设置断点,进去LeNet
-
跳转到lenet.py的
class LeNet(nn.Module)
的__init__(self, classes)
内,构建网络层(如卷积层)就是在模型的__init__
中完成的。 -
单步运行后会跳出
__init__
回到主函数,完成了模型的初始化。我们再设置断点,运行至断点处。 -
进入
net(inputs)
,跳转到module.py的class Module(object)
的__call__
函数。运行到self.forward
,进入。 -
跳转到lenet.py的
class LeNet(nn.Module)
的forward(self, x)
内。单步运行,最后会返回out。 -
然后会返回到module.py,得到result,最后返回result。从下图可知,前向传播计算得到result靠的就是
self.forward
,也就是你自定义网络中的forward
函数。 -
单步运行,返回result后跳回主程序。返回的result就是我们一次迭代计算的output。
于是我们知道了,首先定义一个模型需要继承nn.module
,然后重写__init__
和__forward__
函数。其中,__init__
构建网络层,在模型初始化时需要运行__init__
函数;在数据前向传播过程中,需要用到__forward__
函数,因此需要在__forward__
函数中拼接网络层。
2. nn.Module属性
所有的模型都是nn.Module
的子类,就像在数据部分中,我们自定义的数据集都得继承dataset这个类一样。在本节的例子中,LeNet就需要继承nn.Module
,因此LeNet也是nn.Module
类的,包括LeNet中的网络层也是nn.Module
类。
下图展示了pytorch的神经网络包torch.nn
,本节重点研究nn.Module
。
在nn.module
类中,有8个有序字典属性用于管理模型。这里我们只需要关注parameters和modules。
下面我们观察nn.module
的创建以及属性管理的机制。
实验:
-
设置断点,进去LeNet
-
可以看到,LeNet是继承于
nn.Module
的子类,进入super(LeNet, self).__init__()
。 -
跳转到module.py的
class Module(object)
的__init__(self)
中。可以看到,在这里进行了8个有序字典属性的初始化。 -
跳出函数,回到lenet.py,可以看到LeNet类的初始化完成,已经具有了nn.Module的8个属性,现在暂时还是空的。
-
进入
nn.Conv2d(3, 6, 5)
观察网络层的构建。 -
跳转到conv.py的
class Conv2d(_ConvNd)
的__init__
函数。Conv2d是继承于_ConvNd的。运行到super(Conv2d, self).__init__
,然后进入。 -
进入会跳转到utils.py,跳出再进入就能到达conv.py的
class _ConvNd
中,可以发现_ConvNd也是继承于nn.Module
的,所以网络层也都是nn.Module
的子类。 -
进入
super(_ConvNd, self).__init__
,就又跳转到module.py的class Module(object)
的__init__(self)
中,这里进行了8个有序字典属性的初始化。跳出函数,回到lenet.py的__init__
,此时已经完成了conv1网络层的初始化。 -
现在再看LeNet类的属性,发现
_modules
多了一个元素,就是我们刚刚加入的网络层conv1。键是conv1
,值是nn.Conv2d(3, 6, 5)
。 -
conv1同样也是
nn.Module
的子类,也有这8个属性,我们同样可以查看。发现_parameters
有元素,而_modules
没有。这是因为conv1没有嵌套的网络层了,并且conv1有许多的可训练参数。而_parameters
是继承于tensor类的,所以有tensor的属性。 -
当lenet.py的
__init__
全部完成以后,跳转回主函数可以发现,全部网络层都已添加到属性_modules
中。
3. nn.Module总结
-
一个module可以包含多个子module
-
一个module相当于一个运算,必须实现forward()函数
-
每个module都有8个有序字典管理它的属性