From cf7229d2c2827b0a0c84a047f09b1b464e6e5dc7 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 21 Jan 2019 22:05:13 +0800 Subject: [PATCH] Polish code test=develop --- .../fluid/tests/unittests/test_imperative_gan.py | 2 +- .../fluid/tests/unittests/test_imperative_resnet.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gan.py b/python/paddle/fluid/tests/unittests/test_imperative_gan.py index 4fe286f85ec..991991ac6d0 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gan.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gan.py @@ -135,7 +135,7 @@ class TestImperativeMnist(unittest.TestCase): scope.find_var(param.name).get_tensor()) dy_params = dict() - with fluid.imperative.guard(): + with fluid.imperative.guard(device=None): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index fcf0f4a2d8a..7295b1de091 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -147,7 +147,7 @@ class BottleneckBlock(fluid.imperative.Layer): class ResNet(fluid.imperative.Layer): - def __init__(self, layers=50, class_dim=1000): + def __init__(self, layers=50, class_dim=102): super(ResNet, self).__init__() self.layers = layers @@ -208,6 +208,7 @@ class TestImperativeResnet(unittest.TestCase): seed = 90 batch_size = train_parameters["batch_size"] + batch_num = 1 with fluid.imperative.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed @@ -227,7 +228,7 @@ class TestImperativeResnet(unittest.TestCase): dy_param_init_value[param.name] = param._numpy() for batch_id, data in enumerate(train_reader()): - if batch_id >= 1: + if batch_id >= batch_num: break dy_x_data = np.array( @@ -313,7 +314,7 @@ class TestImperativeResnet(unittest.TestCase): static_param_init_value[static_param_name_list[i]] = out[i] for batch_id, data in enumerate(train_reader()): - if batch_id >= 1: + if batch_id >= batch_num: break static_x_data = np.array( @@ -368,6 +369,7 @@ class TestImperativeResnet(unittest.TestCase): seed = 90 batch_size = train_parameters["batch_size"] + batch_num = 1 with fluid.imperative.guard(device=None): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed @@ -387,7 +389,7 @@ class TestImperativeResnet(unittest.TestCase): dy_param_init_value[param.name] = param._numpy() for batch_id, data in enumerate(train_reader()): - if batch_id >= 1: + if batch_id >= batch_num: break dy_x_data = np.array( @@ -473,7 +475,7 @@ class TestImperativeResnet(unittest.TestCase): static_param_init_value[static_param_name_list[i]] = out[i] for batch_id, data in enumerate(train_reader()): - if batch_id >= 1: + if batch_id >= batch_num: break static_x_data = np.array( -- GitLab