提交 cf7229d2 编写于 作者: M minqiyang

Polish code

test=develop
上级 cafbd62e
...@@ -135,7 +135,7 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -135,7 +135,7 @@ class TestImperativeMnist(unittest.TestCase):
scope.find_var(param.name).get_tensor()) scope.find_var(param.name).get_tensor())
dy_params = dict() dy_params = dict()
with fluid.imperative.guard(): with fluid.imperative.guard(device=None):
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
......
...@@ -147,7 +147,7 @@ class BottleneckBlock(fluid.imperative.Layer): ...@@ -147,7 +147,7 @@ class BottleneckBlock(fluid.imperative.Layer):
class ResNet(fluid.imperative.Layer): class ResNet(fluid.imperative.Layer):
def __init__(self, layers=50, class_dim=1000): def __init__(self, layers=50, class_dim=102):
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.layers = layers self.layers = layers
...@@ -208,6 +208,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -208,6 +208,7 @@ class TestImperativeResnet(unittest.TestCase):
seed = 90 seed = 90
batch_size = train_parameters["batch_size"] batch_size = train_parameters["batch_size"]
batch_num = 1
with fluid.imperative.guard(): with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
...@@ -227,7 +228,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -227,7 +228,7 @@ class TestImperativeResnet(unittest.TestCase):
dy_param_init_value[param.name] = param._numpy() dy_param_init_value[param.name] = param._numpy()
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id >= 1: if batch_id >= batch_num:
break break
dy_x_data = np.array( dy_x_data = np.array(
...@@ -313,7 +314,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -313,7 +314,7 @@ class TestImperativeResnet(unittest.TestCase):
static_param_init_value[static_param_name_list[i]] = out[i] static_param_init_value[static_param_name_list[i]] = out[i]
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id >= 1: if batch_id >= batch_num:
break break
static_x_data = np.array( static_x_data = np.array(
...@@ -368,6 +369,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -368,6 +369,7 @@ class TestImperativeResnet(unittest.TestCase):
seed = 90 seed = 90
batch_size = train_parameters["batch_size"] batch_size = train_parameters["batch_size"]
batch_num = 1
with fluid.imperative.guard(device=None): with fluid.imperative.guard(device=None):
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
...@@ -387,7 +389,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -387,7 +389,7 @@ class TestImperativeResnet(unittest.TestCase):
dy_param_init_value[param.name] = param._numpy() dy_param_init_value[param.name] = param._numpy()
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id >= 1: if batch_id >= batch_num:
break break
dy_x_data = np.array( dy_x_data = np.array(
...@@ -473,7 +475,7 @@ class TestImperativeResnet(unittest.TestCase): ...@@ -473,7 +475,7 @@ class TestImperativeResnet(unittest.TestCase):
static_param_init_value[static_param_name_list[i]] = out[i] static_param_init_value[static_param_name_list[i]] = out[i]
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id >= 1: if batch_id >= batch_num:
break break
static_x_data = np.array( static_x_data = np.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册