使用pytorch进行网络模型的搭建、保存与加载,是非常快速、方便的。
搭建 ConvNet
所有的网络都要继承 torch.nn.Module ,然后在构造函数中使用 torch.nn 中的提供的接口定义 layer 的属性,最后,在 forward 函数中将各个 layer 连接起来。
下面,以 LeNet 为例:
class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) out = self.fc3(x) return out
这样一来,我们就搭建好了网络模型,是不是很简洁明了呢?此外,还可以使用 torch.nn.Sequential ,更方便进行模块化的定义,如下:
class LeNetSeq(nn.Module): def __init__(self): super(LeNetSeq, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2), ) self.fc = nn.Sequential( nn.Linear(16*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10) ) def forward(self, x): x = self.conv(x) x = out.view(x.size(0), -1) out = self.fc(x) return out
Module 有很多属性,可以查看权重、参数等等;如下:
net = lenet.LeNet() print(net) for param in net.parameters(): print(type(param.data), param.size()) print(list(param.data)) print(net.state_dict().keys()) #参数的keys for key in net.state_dict():#模型参数 print key, 'corresponds to', list(net.state_dict()[key])
那么,如何进行参数初始化呢?使用 torch.nn.init ,如下:
def initNetParams(net): '''Init net parameters.''' for m in net.modules(): if isinstance(m, nn.Conv2d): init.xavier_uniform(m.weight) if m.bias: init.constant(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): init.constant(m.weight, 1) init.constant(m.bias, 0) elif isinstance(m, nn.Linear): init.normal(m.weight, std=1e-3) if m.bias: init.constant(m.bias, 0) initNetParams(net)
保存 ConvNet
使用 torch.save()
对网络结构和模型参数的保存,有两种保存方式:
保存整个神经网络的的结构信息和模型参数信息,save 的对象是网络 net;
保存神经网络的训练模型参数,save 的对象是 net.state_dict()。
torch.save(net1, 'net.pkl') # 保存整个神经网络的结构和模型参数 torch.save(net1.state_dict(), 'net_params.pkl') # 只保存神经网络的模型参数
加载 ConvNet
对应上面两种保存方式,重载方式也有两种。
对应第一种完整网络结构信息,重载的时候通过 torch.load('.pth')
直接初始化新的神经网络对象即可。
对应第二种只保存模型参数信息,需要首先导入对应的网络,通过 net.load_state_dict(torch.load('.pth'))
完成模型参数的重载。
在网络比较大的时候,第一种方法会花费较多的时间,所占的存储空间也比较大。
# 保存和加载整个模型 torch.save(model_object, 'model.pth') model = torch.load('model.pth') # 仅保存和加载模型参数 torch.save(model_object.state_dict(), 'params.pth') model_object.load_state_dict(torch.load('params.pth'))