diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index f584f53e853f933b5d8ccc8089dbf682422a593a..07dd42b4041cebbe98b61a6d6457739c86c1e4fa 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -382,6 +382,8 @@ class Variable(object):
             if not self._ivar:
                 self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
+            if persistable:
+                self.block.vars[name] = self
         else:
             self.block.vars[name] = self
             self.op = None
@@ -1188,11 +1190,11 @@ class Block(object):
         raise ValueError("Var {0} is not found recursively".format(name))
 
     def _clear_block(self):
+        # TODO(minqiyang): move this to backward_hooks
         self.desc._clear_block()
 
         for name in self.vars.keys():
-            if not self.vars[name].persistable:
-                del self.vars[name]
+            assert self.vars[name].persistable
 
         del self.ops[:]
 
@@ -1341,11 +1343,8 @@
             backward_refs = _imperative_tracer().trace(
                 op.iop, op.inputs, op.outputs, self.desc,
                 _imperative_current_expected_place_, stop_gradient)
-            print("backward_refs", backward_refs)
-            import sys
-            sys.stdout.flush()
 
-            # TODO(minqiyang): support backward hooks to eager remove backward_refs
+            # TODO(minqiyang): support backward_hooks to eager remove backward_refs
             op.backward_refs = defaultdict(list)
             for k, v in six.iteritems(op.inputs):
                 if k in backward_refs:
diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py
index 6c5961cc63d1c140e0a6f33aac054acdbbe8e8e0..1b0a60df8bc8c5020e7d295917ec957bf68cb5d5 100644
--- a/python/paddle/fluid/imperative/nn.py
+++ b/python/paddle/fluid/imperative/nn.py
@@ -225,9 +225,6 @@ class FC(layers.Layer):
             act=act,
             name=name)
 
-    def parameters(self):
-        return [self._w, self._b]
-
     def _build_once(self, input):
         input_shape = input.shape
         param_shape = [
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
index a07dc2a71295fbbfdeddee2ca5c60b2467fd5f23..f666274690a6e5816f0716639b2c876dc9611b03 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
@@ -131,8 +131,7 @@ class TestImperativeMnist(unittest.TestCase):
                     dy_out = avg_loss._numpy()
 
                     if epoch == 0 and batch_id == 0:
-                        for param in fluid.default_main_program().global_block(
-                        ).all_parameters():
+                        for param in mnist.parameters():
                             dy_param_init_value[param.name] = param._numpy()
 
                     avg_loss._backward()
@@ -142,8 +141,7 @@
                     fluid.default_main_program().global_block()._clear_block()
 
                     dy_param_value = {}
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
+                    for param in mnist.parameters():
                         dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
@@ -169,8 +167,7 @@
             # initialize params and fetch them
             static_param_init_value = {}
             static_param_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in mnist.parameters():
                 static_param_name_list.append(param.name)
 
             out = exe.run(fluid.default_startup_program(),
@@ -204,16 +201,12 @@
         self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
 
         for key, value in six.iteritems(static_param_init_value):
-            if not np.allclose(value, dy_param_init_value[key]):
-                print(key, value, dy_param_value[key])
-            # self.assertTrue(np.allclose(value, dy_param_init_value[key]))
+            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
 
         self.assertTrue(np.allclose(static_out, dy_out))
 
         for key, value in six.iteritems(static_param_value):
-            if not np.allclose(value, dy_param_value[key], atol=1e-6):
-                print(key, value, dy_param_value[key])
-            # self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
+            self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
index e32c84ebcf2bc7e8a3d41bb285d5b554dfe57d61..190e8e352b8859eadbdda9eb1856947d228aa6d9 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
@@ -223,8 +223,7 @@ class TestImperativeResnet(unittest.TestCase):
                 batch_size=batch_size)
 
             dy_param_init_value = {}
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 dy_param_init_value[param.name] = param._numpy()
 
             for batch_id, data in enumerate(train_reader()):
@@ -247,16 +246,14 @@
                 dy_out = avg_loss._numpy()
 
                 if batch_id == 0:
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
+                    for param in resnet.parameters():
                         if param.name not in dy_param_init_value:
                             dy_param_init_value[param.name] = param._numpy()
 
                 avg_loss._backward()
 
                 dy_grad_value = {}
-                for param in fluid.default_main_program().global_block(
-                ).all_parameters():
+                for param in resnet.parameters():
                     if not param.stop_gradient:
                         np_array = np.array(param._ivar._grad_ivar().value()
                                             .get_tensor())
@@ -269,8 +266,7 @@
                 fluid.default_main_program().global_block()._clear_block()
 
                 dy_param_value = {}
-                for param in fluid.default_main_program().global_block(
-                ).all_parameters():
+                for param in resnet.parameters():
                     dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
@@ -302,11 +298,9 @@
             static_param_init_value = {}
             static_param_name_list = []
             static_grad_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 static_param_name_list.append(param.name)
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 if not param.stop_gradient:
                     static_grad_name_list.append(param.name +
                                                  core.grad_var_suffix())
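
Reviewer note (not part of the patch): the framework.py hunks establish an invariant for imperative mode: Variable.__init__ registers only persistable variables (i.e. parameters) in block.vars, so Block._clear_block() can drop the step's ops and merely assert that everything left is persistable instead of scanning and deleting. The toy code below is an illustrative sketch of that invariant only, with hypothetical names (ToyVar, ToyBlock); it is not PaddlePaddle code.

import collections

ToyVar = collections.namedtuple('ToyVar', ['name', 'persistable'])

class ToyBlock(object):
    def __init__(self):
        self.vars = {}  # name -> ToyVar; holds only persistable vars, per the patch
        self.ops = []   # per-step ops, rebuilt on every iteration

    def add_var(self, var):
        # mirrors Variable.__init__ in imperative mode: register only
        # persistable variables (parameters) in block.vars
        if var.persistable:
            self.vars[var.name] = var

    def clear_block(self):
        # mirrors Block._clear_block(): every surviving var must be
        # persistable, so clearing per-step state never drops a parameter
        for name in self.vars:
            assert self.vars[name].persistable
        del self.ops[:]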
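
Reviewer note (not part of the patch): the test hunks stop enumerating fluid.default_main_program().global_block().all_parameters() and instead ask the model itself via mnist.parameters() / resnet.parameters(), which is what makes FC's hand-written parameters() override removable. A minimal sketch of the recursive collection such a Layer.parameters() would perform follows; the attribute names (_parameters, _sub_layers) and the ToyLayer class are assumptions for illustration, not the library's actual internals.

class ToyLayer(object):
    def __init__(self):
        self._parameters = {}  # parameters created directly by this layer
        self._sub_layers = {}  # child layers, e.g. the FC inside MNIST

    def parameters(self):
        # collect this layer's own parameters, then recurse into children
        params = list(self._parameters.values())
        for sub in self._sub_layers.values():
            params.extend(sub.parameters())
        return params

Under this contract, and assuming FC registers _w and _b with its base class, the removed FC.parameters() returning [self._w, self._b] is redundant, and both the imperative and static test paths iterate the same parameter list.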