未验证 提交 61e218d0 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #63 from LielinJiang/fix-test-model

Fix unittest for predict and evaluate
...@@ -20,6 +20,8 @@ import unittest ...@@ -20,6 +20,8 @@ import unittest
import os import os
import cv2 import cv2
import numpy as np import numpy as np
import tempfile
import shutil
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -36,14 +38,6 @@ from hapi.download import get_weights_path_from_url ...@@ -36,14 +38,6 @@ from hapi.download import get_weights_path_from_url
class LeNetDygraph(fluid.dygraph.Layer): class LeNetDygraph(fluid.dygraph.Layer):
"""LeNet model from
`"LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324.`_
Args:
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 10.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
"""
def __init__(self, num_classes=10, classifier_activation='softmax'): def __init__(self, num_classes=10, classifier_activation='softmax'):
super(LeNetDygraph, self).__init__() super(LeNetDygraph, self).__init__()
...@@ -137,6 +131,15 @@ class TestEvaluatePredict(unittest.TestCase): ...@@ -137,6 +131,15 @@ class TestEvaluatePredict(unittest.TestCase):
low_level_lenet_dygraph_train(self.lenet_dygraph, train_dataloader) low_level_lenet_dygraph_train(self.lenet_dygraph, train_dataloader)
self.acc1 = low_level_dynamic_evaluate(self.lenet_dygraph, self.acc1 = low_level_dynamic_evaluate(self.lenet_dygraph,
val_dataloader) val_dataloader)
self.save_dir = tempfile.mkdtemp()
self.weight_path = os.path.join(self.save_dir, 'lenet')
fluid.dygraph.save_dygraph(self.lenet_dygraph.state_dict(), self.weight_path)
fluid.disable_dygraph()
def tearDown(self):
shutil.rmtree(self.save_dir)
def evaluate(self, dynamic): def evaluate(self, dynamic):
fluid.enable_dygraph(self.device) if dynamic else None fluid.enable_dygraph(self.device) if dynamic else None
...@@ -144,67 +147,44 @@ class TestEvaluatePredict(unittest.TestCase): ...@@ -144,67 +147,44 @@ class TestEvaluatePredict(unittest.TestCase):
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')] inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')] labels = [Input([None, 1], 'int64', name='label')]
if fluid.in_dygraph_mode(): val_dataloader = fluid.io.DataLoader(
feed_list = None
else:
feed_list = [x.forward() for x in inputs + labels]
self.train_dataloader = fluid.io.DataLoader(
self.train_dataset,
places=self.device,
batch_size=64,
feed_list=feed_list)
self.val_dataloader = fluid.io.DataLoader(
self.val_dataset, self.val_dataset,
places=self.device, places=self.device,
batch_size=64, batch_size=64,
feed_list=feed_list) return_list=True)
self.test_dataloader = fluid.io.DataLoader(
self.test_dataset,
places=self.device,
batch_size=64,
feed_list=feed_list)
model = LeNet() model = LeNet()
model.load_dict(self.lenet_dygraph.state_dict())
model.load(self.weight_path)
model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels) model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels)
result = model.evaluate(self.val_dataloader) result = model.evaluate(val_dataloader)
np.testing.assert_allclose(result['acc'], self.acc1) np.testing.assert_allclose(result['acc'], self.acc1)
if fluid.in_dygraph_mode():
fluid.disable_dygraph()
def predict(self, dynamic): def predict(self, dynamic):
fluid.enable_dygraph(self.device) if dynamic else None fluid.enable_dygraph(self.device) if dynamic else None
inputs = [Input([-1, 1, 28, 28], 'float32', name='image')] inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')] labels = [Input([None, 1], 'int64', name='label')]
if fluid.in_dygraph_mode(): test_dataloader = fluid.io.DataLoader(
feed_list = None
else:
feed_list = [x.forward() for x in inputs + labels]
self.train_dataloader = fluid.io.DataLoader(
self.train_dataset,
places=self.device,
batch_size=64,
feed_list=feed_list)
self.val_dataloader = fluid.io.DataLoader(
self.val_dataset,
places=self.device,
batch_size=64,
feed_list=feed_list)
self.test_dataloader = fluid.io.DataLoader(
self.test_dataset, self.test_dataset,
places=self.device, places=self.device,
batch_size=64, batch_size=64,
feed_list=feed_list) return_list=True)
model = LeNet() model = LeNet()
model.load_dict(self.lenet_dygraph.state_dict())
model.load(self.weight_path)
model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels) model.prepare(metrics=Accuracy(), inputs=inputs, labels=labels)
output = model.predict(self.test_dataloader, stack_outputs=True) output = model.predict(test_dataloader, stack_outputs=True)
np.testing.assert_equal(output[0].shape[0], len(self.test_dataset)) np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))
...@@ -212,6 +192,9 @@ class TestEvaluatePredict(unittest.TestCase): ...@@ -212,6 +192,9 @@ class TestEvaluatePredict(unittest.TestCase):
np.testing.assert_allclose(acc, self.acc1) np.testing.assert_allclose(acc, self.acc1)
if fluid.in_dygraph_mode():
fluid.disable_dygraph()
def test_evaluate_dygraph(self): def test_evaluate_dygraph(self):
self.evaluate(True) self.evaluate(True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册