提交 87a1b53b 编写于 作者: L LiuChiaChi

fix unittets for test_model.py

上级 f7f40635
...@@ -570,11 +570,11 @@ class TestModelFunction(unittest.TestCase): ...@@ -570,11 +570,11 @@ class TestModelFunction(unittest.TestCase):
np.random.random((1, 1, 28, 28)), dtype=np.float32) np.random.random((1, 1, 28, 28)), dtype=np.float32)
label = np.array(np.random.rand(1, 1), dtype=np.int64) label = np.array(np.random.rand(1, 1), dtype=np.int64)
if initial == "train_batch": if initial == "train_batch":
model.train_batch(img, label) model.train_batch([img], [label])
elif initial == "eval_batch": elif initial == "eval_batch":
model.eval_batch(img, label) model.eval_batch([img], [label])
else: else:
model.test_batch(img) model.test_batch([img])
model.save(save_dir, training=False) model.save(save_dir, training=False)
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册