Commit 1bf4b8ab authored by minqiyang

keep parameters in block

test=develop
Parent 8fe0c0c5
@@ -382,6 +382,8 @@ class Variable(object):
             if not self._ivar:
                 self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
+            if persistable:
+                self.block.vars[name] = self
         else:
             self.block.vars[name] = self
         self.op = None
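The hunk above narrows imperative-mode registration: once the backing `_ivar` exists, a variable enters `self.block.vars` only when it is persistable, i.e. a parameter. A minimal runnable sketch of that rule, using hypothetical simplified Block/Variable classes rather than Paddle's real ones:

class Block(object):
    def __init__(self):
        self.vars = {}

class Variable(object):
    def __init__(self, block, name, persistable=False, imperative=True):
        self.block = block
        self.persistable = persistable
        if imperative:
            # new behavior: only persistable vars (parameters) stay in the block
            if persistable:
                block.vars[name] = self
        else:
            # static-graph mode: every variable is tracked as before
            block.vars[name] = self

blk = Block()
Variable(blk, "fc_0.w_0", persistable=True)
Variable(blk, "tmp_0", persistable=False)
assert list(blk.vars) == ["fc_0.w_0"]  # intermediates never enter block.vars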
@@ -1188,11 +1190,11 @@ class Block(object):
             raise ValueError("Var {0} is not found recursively".format(name))
 
     def _clear_block(self):
+        # TODO(minqiyang): move this to backward_hooks
         self.desc._clear_block()
 
         for name in self.vars.keys():
-            if not self.vars[name].persistable:
-                del self.vars[name]
+            assert self.vars[name].persistable
 
         del self.ops[:]
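With that registration rule in place, _clear_block no longer has to delete non-persistable variables one by one: nothing non-persistable can be registered, so the old deletion loop collapses into an assertion of the invariant. A self-contained sketch of the resulting contract (fake classes, not Paddle's):

class _Param(object):
    persistable = True

class Block(object):
    def __init__(self):
        self.vars = {"fc_0.w_0": _Param()}  # only parameters are registered
        self.ops = ["op1", "op2"]

    def _clear_block(self):
        # the deletion loop is gone: assert the invariant, then drop the ops
        for name in self.vars:
            assert self.vars[name].persistable
        del self.ops[:]

b = Block()
b._clear_block()
assert b.vars and not b.ops  # parameters survive, traced ops are released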
@@ -1341,11 +1343,8 @@ class Block(object):
             backward_refs = _imperative_tracer().trace(
                 op.iop, op.inputs, op.outputs, self.desc,
                 _imperative_current_expected_place_, stop_gradient)
-            print("backward_refs", backward_refs)
-            import sys
-            sys.stdout.flush()
 
-            # TODO(minqiyang): support backward hooks to eager remove backward_refs
+            # TODO(minqiyang): support backward_hooks to eager remove backward_refs
             op.backward_refs = defaultdict(list)
             for k, v in six.iteritems(op.inputs):
                 if k in backward_refs:
......
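The last framework hunk only drops debug prints around the tracer call. The surviving loop keeps, per op, just the input slots the tracer reports as needed by backward, which is what the backward_hooks TODO wants to release eagerly. A standalone sketch of that filtering with hypothetical data, not the tracer's real return value:

from collections import defaultdict
import six

def collect_backward_refs(op_inputs, backward_refs):
    # retain only the input slots named by the tracer as needed for backward
    refs = defaultdict(list)
    for k, v in six.iteritems(op_inputs):
        if k in backward_refs:
            refs[k] = v
    return refs

inputs = {"X": ["x_0"], "Y": ["y_0"], "Bias": ["b_0"]}
print(collect_backward_refs(inputs, backward_refs={"X", "Y"}))
# only X and Y are kept alive for the backward pass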
@@ -225,9 +225,6 @@ class FC(layers.Layer):
             act=act,
             name=name)
 
-    def parameters(self):
-        return [self._w, self._b]
-
     def _build_once(self, input):
         input_shape = input.shape
         param_shape = [
......
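With parameters kept in the block, FC's hand-written parameters() override becomes redundant: the commit assumes the base layers.Layer can enumerate its own parameters generically. A hedged sketch of that idea (the Layer below is illustrative, not Paddle's actual implementation):

class Parameter(object):
    def __init__(self, name):
        self.name = name
        self.persistable = True

class Layer(object):
    def parameters(self):
        # generic collection: anything parameter-like attached to the layer
        return [v for v in vars(self).values() if isinstance(v, Parameter)]

class FC(Layer):
    def __init__(self):
        self._w = Parameter("fc_0.w_0")
        self._b = Parameter("fc_0.b_0")
        # no parameters() override needed anymore

print(sorted(p.name for p in FC().parameters()))  # ['fc_0.b_0', 'fc_0.w_0']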
@@ -131,8 +131,7 @@ class TestImperativeMnist(unittest.TestCase):
                 dy_out = avg_loss._numpy()
 
                 if epoch == 0 and batch_id == 0:
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
+                    for param in mnist.parameters():
                         dy_param_init_value[param.name] = param._numpy()
 
                 avg_loss._backward()
@@ -142,8 +141,7 @@ class TestImperativeMnist(unittest.TestCase):
             fluid.default_main_program().global_block()._clear_block()
 
         dy_param_value = {}
-        for param in fluid.default_main_program().global_block(
-        ).all_parameters():
+        for param in mnist.parameters():
             dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
@@ -169,8 +167,7 @@ class TestImperativeMnist(unittest.TestCase):
             # initialize params and fetch them
             static_param_init_value = {}
             static_param_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in mnist.parameters():
                 static_param_name_list.append(param.name)
 
             out = exe.run(fluid.default_startup_program(),
@@ -204,16 +201,12 @@ class TestImperativeMnist(unittest.TestCase):
         self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
 
         for key, value in six.iteritems(static_param_init_value):
-            if not np.allclose(value, dy_param_init_value[key]):
-                print(key, value, dy_param_value[key])
-            # self.assertTrue(np.allclose(value, dy_param_init_value[key]))
+            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
 
         self.assertTrue(np.allclose(static_out, dy_out))
 
         for key, value in six.iteritems(static_param_value):
-            if not np.allclose(value, dy_param_value[key], atol=1e-6):
-                print(key, value, dy_param_value[key])
-            # self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
+            self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
 
 
 if __name__ == '__main__':
......
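Every test hunk above swaps the same pattern: instead of scanning fluid.default_main_program().global_block().all_parameters(), snapshots are taken through the model's own parameters(). A self-contained sketch of that snapshot pattern, with fake stand-ins for the model and parameter objects:

import numpy as np

class FakeParam(object):
    def __init__(self, name, value):
        self.name = name
        self._value = value

    def _numpy(self):
        return np.array(self._value)

class FakeModel(object):
    def parameters(self):
        return [FakeParam("fc_0.w_0", [1.0, 2.0]), FakeParam("fc_0.b_0", [0.0])]

def snapshot(model):
    # the pattern the updated tests use: ask the model, not the global block
    return {p.name: p._numpy() for p in model.parameters()}

mnist = FakeModel()
dy_param_init_value = snapshot(mnist)   # before training
# ... run a training step, then:
dy_param_value = snapshot(mnist)        # after training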
@@ -223,8 +223,7 @@ class TestImperativeResnet(unittest.TestCase):
             batch_size=batch_size)
 
         dy_param_init_value = {}
-        for param in fluid.default_main_program().global_block(
-        ).all_parameters():
+        for param in resnet.parameters():
             dy_param_init_value[param.name] = param._numpy()
 
         for batch_id, data in enumerate(train_reader()):
@@ -247,16 +246,14 @@ class TestImperativeResnet(unittest.TestCase):
                 dy_out = avg_loss._numpy()
 
                 if batch_id == 0:
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
+                    for param in resnet.parameters():
                         if param.name not in dy_param_init_value:
                             dy_param_init_value[param.name] = param._numpy()
 
                 avg_loss._backward()
 
                 dy_grad_value = {}
-                for param in fluid.default_main_program().global_block(
-                ).all_parameters():
+                for param in resnet.parameters():
                     if not param.stop_gradient:
                         np_array = np.array(param._ivar._grad_ivar().value()
                                             .get_tensor())
@@ -269,8 +266,7 @@ class TestImperativeResnet(unittest.TestCase):
             fluid.default_main_program().global_block()._clear_block()
 
         dy_param_value = {}
-        for param in fluid.default_main_program().global_block(
-        ).all_parameters():
+        for param in resnet.parameters():
             dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
@@ -302,11 +298,9 @@ class TestImperativeResnet(unittest.TestCase):
             static_param_init_value = {}
             static_param_name_list = []
             static_grad_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 static_param_name_list.append(param.name)
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 if not param.stop_gradient:
                     static_grad_name_list.append(param.name +
                                                  core.grad_var_suffix())
......
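The resnet hunks make the same substitution and additionally snapshot gradients. A self-contained sketch of that gradient-collection pattern with fake stand-ins; in the real test the array comes from param._ivar._grad_ivar().value().get_tensor() and the key suffix from core.grad_var_suffix():

import numpy as np

GRAD_SUFFIX = "@GRAD"  # assumption standing in for core.grad_var_suffix()

class FakeParam(object):
    def __init__(self, name, grad, stop_gradient=False):
        self.name = name
        self.stop_gradient = stop_gradient
        self._grad = grad

def snapshot_grads(params):
    # mirror the test: skip parameters excluded from backward
    return {p.name + GRAD_SUFFIX: np.array(p._grad)
            for p in params if not p.stop_gradient}

params = [FakeParam("conv1.w_0", [0.1]), FakeParam("bn1.mean", [0.0], True)]
print(snapshot_grads(params))  # only conv1.w_0@GRAD appears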