From cdc4e6620d8eb91a98c3aa5b440c369a18752f8f Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Tue, 17 Nov 2020 17:24:38 +0800 Subject: [PATCH] fix lenet num classes (#28642) --- python/paddle/vision/models/lenet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/vision/models/lenet.py b/python/paddle/vision/models/lenet.py index 119be85db5..2fb50fc17b 100644 --- a/python/paddle/vision/models/lenet.py +++ b/python/paddle/vision/models/lenet.py @@ -49,7 +49,8 @@ class LeNet(nn.Layer): if num_classes > 0: self.fc = nn.Sequential( - nn.Linear(400, 120), nn.Linear(120, 84), nn.Linear(84, 10)) + nn.Linear(400, 120), + nn.Linear(120, 84), nn.Linear(84, num_classes)) def forward(self, inputs): x = self.features(inputs) -- GitLab