Model Saving and Loading

1. Saving and Loading Objects in PyTorch

1.1 torch.save

torch.save(obj, f)

Main parameters:

  • obj: the object to save
  • f: the output path
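
For example, a minimal sketch of saving objects (the file names here are illustrative, not from the original experiment):

import torch

# torch.save serializes any picklable object, not just models
t = torch.randn(3, 3)
torch.save(t, "./tensor.pkl")
torch.save({"weights": t, "step": 100}, "./bundle.pkl")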

1.2 torch.load

torch.load(f, map_location=None)

Main parameters:

  • f: the file path
  • map_location: specifies where to remap the stored tensors, e.g. CPU or GPU
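
For example, a sketch of loading a GPU-saved checkpoint on a CPU-only machine (the file name is illustrative):

import torch

# map_location remaps the saved storages to the given device
obj = torch.load("./bundle.pkl", map_location="cpu")

# an equivalent form using an explicit device object
obj = torch.load("./bundle.pkl", map_location=torch.device("cpu"))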

2. Saving Models in PyTorch

The network to be saved in the experiments below:

import torch
import torch.nn as nn


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        # fill every parameter with a fixed value so "training" is easy to spot
        for p in self.parameters():
            p.data.fill_(20191104)

2.1 Method 1: save the entire Module

torch.save(net, path)

This packs up and saves the whole network. It is convenient, but it costs more time and storage, and because the entire object is pickled, loading it later requires the original class definition to be available.

2.2 Method 2: save the model parameters

state_dict = net.state_dict()
torch.save(state_dict, path)

This saves all of the network's weights, obtained via state_dict().

Let's take a look at the state_dict function:

state_dict(destination=None, prefix='', keep_vars=False)

Purpose: returns a dictionary containing the whole state of the module, including parameters and persistent buffers (e.g. running averages). The keys are the corresponding parameter and buffer names.

To load the model you first rebuild the network and then load the state_dict into it. This uses less storage and saves faster; it is the recommended method.
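
To see what a state_dict actually contains, here is a small sketch using the LeNet2 class defined above:

# print every parameter/buffer name together with its shape
net = LeNet2(classes=2)
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# e.g. features.0.weight (6, 3, 5, 5), features.0.bias (6,), ...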

2.3 Model-saving experiment

net = LeNet2(classes=2019)

# "training"
print("Before training: ", net.features[0].weight[0, ...])
net.initialize()
print("After training: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# save the entire model
torch.save(net, path_model)

# save only the model parameters
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

Output:

Before training:  tensor([[[ 0.0840,  0.1148,  0.0477,  0.0700, -0.0106],
         [ 0.1000, -0.0980, -0.0225, -0.0794, -0.0793],
         [ 0.0915, -0.0229, -0.0222, -0.0409,  0.0894],
         [ 0.0444,  0.0796,  0.0176, -0.0989,  0.0571],
         [ 0.0642,  0.0691,  0.0382, -0.0549,  0.0555]],

        [[-0.1023,  0.0449, -0.0966, -0.0305, -0.0489],
         [-0.0781,  0.0534,  0.0486,  0.0315,  0.0260],
         [ 0.0583, -0.0244,  0.0346, -0.0623, -0.1030],
         [ 0.1148, -0.1038, -0.0653,  0.0300,  0.0351],
         [-0.1022, -0.0612, -0.0225,  0.0810, -0.0252]],

        [[ 0.0222, -0.0831, -0.0930, -0.1112,  0.0207],
         [ 0.0246, -0.0831, -0.0317, -0.0591, -0.0993],
         [ 0.0870, -0.0170,  0.0817,  0.1024, -0.0824],
         [-0.0761,  0.0392,  0.0600, -0.0597,  0.0248],
         [ 0.0553, -0.0301, -0.0230, -0.0375,  0.0828]]],
       grad_fn=<SelectBackward>)
After training:  tensor([[[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],

        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],

        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]]],
       grad_fn=<SelectBackward>)

This produces the following two files: model.pkl and model_state_dict.pkl.

3. Loading Models in PyTorch

3.1 Method 1: load the entire Module

Experiment:

path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load.features[0].weight[0, ...])

Output:

tensor([[[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],

        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],

        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]]],
       grad_fn=<SelectBackward>)

3.2 Method 2: load the model parameters

Use the load_state_dict function to load the model parameters:

load_state_dict(state_dict: Dict[str, torch.Tensor], strict: bool = True)

Purpose: copies parameters and buffers from state_dict into this module and its descendants. If strict is True, the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
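
When the architectures only partially match, for example after swapping the classification head, strict=False can be combined with shape filtering. A sketch based on the LeNet2 example (the class count 10 is arbitrary):

# keep only entries whose name and shape match the new network
net_new = LeNet2(classes=10)  # classifier.4 now has a different shape
state_dict_load = torch.load("./model_state_dict.pkl")

new_sd = net_new.state_dict()
filtered = {k: v for k, v in state_dict_load.items()
            if k in new_sd and v.shape == new_sd[k].shape}

# load_state_dict returns the keys that did not match
missing, unexpected = net_new.load_state_dict(filtered, strict=False)
print("missing keys:", missing)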

Experiment:

net_new = LeNet2(classes=2019)
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

print("Before loading: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("After loading: ", net_new.features[0].weight[0, ...])
print("state_dict keys: ", state_dict_load.keys())

Output:

Before loading:  tensor([[[ 0.0452, -0.0556,  0.0299,  0.1078,  0.0888],
         [ 0.0699, -0.0578,  0.0267,  0.1028,  0.0795],
         [ 0.0793,  0.0856,  0.0328, -0.0826,  0.0112],
         [-0.0618,  0.1154, -0.0941, -0.0513, -0.0798],
         [ 0.0535,  0.0547, -0.0327, -0.0607,  0.0165]],

        [[-0.0204, -0.0709,  0.0989, -0.1087, -0.0880],
         [-0.0235, -0.1067, -0.0354, -0.1050, -0.0106],
         [ 0.0425, -0.0373,  0.0458,  0.1115,  0.0366],
         [ 0.0450, -0.0708, -0.1020,  0.0442,  0.0183],
         [ 0.0998,  0.0840,  0.0096, -0.0083,  0.0167]],

        [[ 0.0956,  0.0375, -0.1030, -0.0864, -0.0391],
         [-0.0204,  0.0717,  0.1090, -0.1011,  0.0208],
         [-0.0459,  0.0744,  0.0349, -0.1036,  0.0664],
         [ 0.1001,  0.0722, -0.0013, -0.0883, -0.1062],
         [ 0.0272,  0.0279, -0.0223, -0.1040,  0.0453]]],
       grad_fn=<SelectBackward>)
After loading:  tensor([[[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],

        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]],

        [[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 20191104., 20191104.]]],
       grad_fn=<SelectBackward>)
state_dict keys:  odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias'])

We can see that all of the weights have been loaded into the network.

4. Resuming Training from a Checkpoint

To avoid losing progress when training is interrupted unexpectedly, save periodic checkpoints of the current training state, so that the next run can resume directly from the latest checkpoint.

A checkpoint should include (but is not limited to) the following:

checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch}

4.1 Saving a checkpoint

Experiment:

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed


set_seed(1)  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyper-parameters
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # learning-rate decay policy

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

start_epoch = -1
for epoch in range(start_epoch+1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # collect classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # log training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    if epoch > 5:
        # simulate an unexpected interruption
        print("Training interrupted unexpectedly...")
        break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid_curve holds one point per epoch, so map the record points onto iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

Output:

Training:Epoch[000/010] Iteration[010/010] Loss: 0.6852 Acc:54.37%
Valid: Epoch[000/010] Iteration[002/002] Loss: 0.4895 Acc:54.37%
Training:Epoch[001/010] Iteration[010/010] Loss: 0.4095 Acc:85.00%
Valid: Epoch[001/010] Iteration[002/002] Loss: 0.0386 Acc:85.00%
Training:Epoch[002/010] Iteration[010/010] Loss: 0.1624 Acc:93.12%
Valid: Epoch[002/010] Iteration[002/002] Loss: 0.0018 Acc:93.12%
Training:Epoch[003/010] Iteration[010/010] Loss: 0.2654 Acc:91.25%
Valid: Epoch[003/010] Iteration[002/002] Loss: 0.0009 Acc:91.25%
Training:Epoch[004/010] Iteration[010/010] Loss: 0.1090 Acc:96.25%
Valid: Epoch[004/010] Iteration[002/002] Loss: 0.0108 Acc:96.25%
Training:Epoch[005/010] Iteration[010/010] Loss: 0.0920 Acc:95.62%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.0000 Acc:95.62%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0102 Acc:100.00%
Training interrupted unexpectedly...
(figure: training and validation loss curves plotted by the script above)

The checkpointing logic is shown below. With checkpoint_interval = 5, a checkpoint is written whenever (epoch+1) is a multiple of 5, so the state at epoch 4 (the 5th epoch) was saved as checkpoint_4_epoch.pkl:

if (epoch+1) % checkpoint_interval == 0:

    checkpoint = {"model_state_dict": net.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)

4.2 Loading a checkpoint

Experiment:

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed

set_seed(1)  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyper-parameters
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)  # learning-rate decay policy


# ============================ step 5+/5 resume from checkpoint ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # collect classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # log training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid_curve holds one point per epoch, so map the record points onto iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

Output:

Training:Epoch[005/010] Iteration[010/010] Loss: 0.0406 Acc:99.38%
Valid: Epoch[005/010] Iteration[002/002] Loss: 0.0000 Acc:99.38%
Training:Epoch[006/010] Iteration[010/010] Loss: 0.0276 Acc:98.75%
Valid: Epoch[006/010] Iteration[002/002] Loss: 0.0000 Acc:98.75%
Training:Epoch[007/010] Iteration[010/010] Loss: 0.0074 Acc:100.00%
Valid: Epoch[007/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[008/010] Iteration[010/010] Loss: 0.0029 Acc:100.00%
Valid: Epoch[008/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
Training:Epoch[009/010] Iteration[010/010] Loss: 0.0004 Acc:100.00%
Valid: Epoch[009/010] Iteration[002/002] Loss: 0.0000 Acc:100.00%
(figure: training and validation loss curves after resuming from epoch 5)

The resume logic is set up as follows:

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch  # don't forget this
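
Alternatively, if the scheduler's state_dict() was stored in the checkpoint (as sketched in section 4, under the assumed key "scheduler_state_dict"), the scheduler can be restored in full instead of setting last_epoch by hand:

# a sketch, assuming the checkpoint dict contains "scheduler_state_dict"
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])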