Skip to content

Commit 10b22dc

Browse files
boscotsangsoumith
authored andcommitted
Fix test_epoch typo (#183)
test_loss /= len(test_loader.dataset) should be test_loss /= len(data_loader.dataset)
1 parent 08be28e commit 10b22dc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mnist_hogwild/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_epoch(model, data_loader):
5555
pred = output.data.max(1)[1] # get the index of the max log-probability
5656
correct += pred.eq(target.data).cpu().sum()
5757

58-
test_loss /= len(test_loader.dataset)
58+
test_loss /= len(data_loader.dataset)
5959
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
6060
test_loss, correct, len(data_loader.dataset),
6161
100. * correct / len(data_loader.dataset)))

0 commit comments

Comments
 (0)