提交 7afd7d4b 编写于 作者: L LiuChiaChi

fix unittests bugs

上级 cc71c7b9
...@@ -66,34 +66,6 @@ class LeNetDygraph(paddle.nn.Layer): ...@@ -66,34 +66,6 @@ class LeNetDygraph(paddle.nn.Layer):
return x return x
class LeNetDeclarative(fluid.dygraph.Layer):
def __init__(self, num_classes=10):
super(LeNetDeclarative, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
if num_classes > 0:
self.fc = Sequential(
Linear(400, 120), Linear(120, 84), Linear(84, 10))
@declarative
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class MnistDataset(MNIST): class MnistDataset(MNIST):
def __init__(self, mode, return_label=True, sample_num=None): def __init__(self, mode, return_label=True, sample_num=None):
super(MnistDataset, self).__init__(mode=mode) super(MnistDataset, self).__init__(mode=mode)
...@@ -577,7 +549,7 @@ class TestModelFunction(unittest.TestCase): ...@@ -577,7 +549,7 @@ class TestModelFunction(unittest.TestCase):
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
paddle.enable_static() paddle.enable_static()
def test_export_deploy_model_without_inputs_in_dygraph(self): def test_dygraph_export_deploy_model_without_inputs(self):
mnist_data = MnistDataset(mode='train') mnist_data = MnistDataset(mode='train')
paddle.disable_static() paddle.disable_static()
for initial in ["fit", "train_batch", "eval_batch", "test_batch"]: for initial in ["fit", "train_batch", "eval_batch", "test_batch"]:
...@@ -586,6 +558,8 @@ class TestModelFunction(unittest.TestCase): ...@@ -586,6 +558,8 @@ class TestModelFunction(unittest.TestCase):
os.makedirs(save_dir) os.makedirs(save_dir)
net = LeNet() net = LeNet()
model = Model(net) model = Model(net)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare( model.prepare(
optimizer=optim, loss=CrossEntropyLoss(reduction="sum")) optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
if initial == "fit": if initial == "fit":
...@@ -614,7 +588,7 @@ class TestRaiseError(unittest.TestCase): ...@@ -614,7 +588,7 @@ class TestRaiseError(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model = Model(net, inputs, labels) model = Model(net, inputs, labels)
def test_export_deploy_model_without_inputs_and_run_in_dygraph(self): def test_save_infer_model_without_inputs_and_run_in_dygraph(self):
paddle.disable_static() paddle.disable_static()
net = MyModel() net = MyModel()
save_dir = tempfile.mkdtemp() save_dir = tempfile.mkdtemp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册