Skip to content

Commit 0984955

Browse files
aromnvidiasoumith
authored andcommitted
save/load optimizer state (#141)
1 parent 3f21078 commit 0984955

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

imagenet/main.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,22 @@ def main():
7171
else:
7272
model = torch.nn.DataParallel(model).cuda()
7373

74-
# optionally resume from a checkpoint
74+
# define loss function (criterion) and optimizer
75+
criterion = nn.CrossEntropyLoss().cuda()
76+
77+
optimizer = torch.optim.SGD(model.parameters(), args.lr,
78+
momentum=args.momentum,
79+
weight_decay=args.weight_decay)
80+
81+
# optionally resume from a checkpoint
7582
if args.resume:
7683
if os.path.isfile(args.resume):
7784
print("=> loading checkpoint '{}'".format(args.resume))
7885
checkpoint = torch.load(args.resume)
7986
args.start_epoch = checkpoint['epoch']
8087
best_prec1 = checkpoint['best_prec1']
8188
model.load_state_dict(checkpoint['state_dict'])
89+
optimizer.load_state_dict(checkpoint['optimizer'])
8290
print("=> loaded checkpoint '{}' (epoch {})"
8391
.format(args.resume, checkpoint['epoch']))
8492
else:
@@ -112,13 +120,6 @@ def main():
112120
batch_size=args.batch_size, shuffle=False,
113121
num_workers=args.workers, pin_memory=True)
114122

115-
# define loss function (criterion) and optimizer
116-
criterion = nn.CrossEntropyLoss().cuda()
117-
118-
optimizer = torch.optim.SGD(model.parameters(), args.lr,
119-
momentum=args.momentum,
120-
weight_decay=args.weight_decay)
121-
122123
if args.evaluate:
123124
validate(val_loader, model, criterion)
124125
return
@@ -140,6 +141,7 @@ def main():
140141
'arch': args.arch,
141142
'state_dict': model.state_dict(),
142143
'best_prec1': best_prec1,
144+
'optimizer' : optimizer.state_dict(),
143145
}, is_best)
144146

145147

0 commit comments

Comments
 (0)