@@ -71,14 +71,22 @@ def main():
     else:
         model = torch.nn.DataParallel(model).cuda()

-    # optionally resume from a checkpoint
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
@@ -112,13 +120,6 @@ def main():
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True)

-    # define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
-
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
-
     if args.evaluate:
         validate(val_loader, model, criterion)
         return
@@ -140,6 +141,7 @@ def main():
             'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
+            'optimizer': optimizer.state_dict(),
         }, is_best)
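Not part of the commit itself, but a minimal sketch of the save/resume pattern this change enables, using a toy model and SGD optimizer in place of the ones built in main() (names such as checkpoint.pth.tar and start_epoch here are illustrative):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the model/optimizer the real script builds from args
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)

# Save: include the optimizer state so momentum buffers survive a restart
torch.save({
    'epoch': 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.pth.tar')

# Resume: the optimizer must already be constructed before its state can be
# restored, which is why the diff moves its creation above the resume block
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
```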