Commit 1bf4b8ab authored by minqiyang

keep parameters in block

test=develop
Parent: 8fe0c0c5
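Review note: a minimal, self-contained sketch (not PaddlePaddle code) of the contract this commit establishes — parameters stay registered in the block, and callers enumerate them through the layer rather than by scanning the global block. All class and parameter names below are hypothetical stand-ins:

    # Toy model of the new contract; Block/Parameter/FCLayer are stand-ins.
    class Block(object):
        def __init__(self):
            # After this commit, imperative mode keeps only persistable
            # variables (i.e. parameters) registered here.
            self.vars = {}

    class Parameter(object):
        def __init__(self, name, block):
            self.name = name
            self.persistable = True   # parameters are always persistable
            block.vars[name] = self   # kept in the block across iterations

    class FCLayer(object):
        def __init__(self, block):
            self._w = Parameter("fc.w", block)
            self._b = Parameter("fc.b", block)

        def parameters(self):
            # What the tests below now call instead of
            # global_block().all_parameters().
            return [self._w, self._b]

    block = Block()
    fc = FCLayer(block)
    assert [p.name for p in fc.parameters()] == ["fc.w", "fc.b"]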
@@ -382,6 +382,8 @@ class Variable(object):
             if not self._ivar:
                 self._ivar = core.VarBase(stop_gradient)
             self._ivar.desc = self.desc
+            if persistable:
+                self.block.vars[name] = self
         else:
             self.block.vars[name] = self
         self.op = None
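Review note: the hunk above sits inside the imperative-mode branch of Variable.__init__ (the guard is outside this hunk). The resulting registration rule, condensed into a hypothetical helper:

    # Hypothetical condensation of the rule; not the actual framework code.
    def register_in_block(block, name, var, persistable, imperative_mode):
        if imperative_mode:
            if persistable:            # parameters must survive _clear_block()
                block.vars[name] = var
            # non-persistable vars (activations etc.) are no longer kept
        else:
            block.vars[name] = var     # static mode still keeps everything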
@@ -1188,11 +1190,11 @@ class Block(object):
             raise ValueError("Var {0} is not found recursively".format(name))
 
     def _clear_block(self):
+        # TODO(minqiyang): move this to backward_hooks
         self.desc._clear_block()
 
         for name in self.vars.keys():
-            if not self.vars[name].persistable:
-                del self.vars[name]
+            assert self.vars[name].persistable
 
         del self.ops[:]
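Review note: since imperative mode now registers only persistable variables, the old delete loop is dead code, and _clear_block can assert the invariant instead. Incidentally, under Python 3 the old `del self.vars[name]` inside a loop over `self.vars.keys()` would raise RuntimeError (the keys view forbids mutation during iteration), so the assertion also sidesteps that. A minimal sketch of the invariant, with hypothetical names:

    # Hypothetical sketch: between iterations the block drops its ops but
    # keeps its vars, all of which must now be persistable parameters.
    def clear_block(block):
        for name in block.vars.keys():
            assert block.vars[name].persistable, name
        del block.ops[:]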
@@ -1341,11 +1343,8 @@ class Block(object):
             backward_refs = _imperative_tracer().trace(
                 op.iop, op.inputs, op.outputs, self.desc,
                 _imperative_current_expected_place_, stop_gradient)
-            print("backward_refs", backward_refs)
-            import sys
-            sys.stdout.flush()
 
-            # TODO(minqiyang): support backward hooks to eager remove backward_refs
+            # TODO(minqiyang): support backward_hooks to eager remove backward_refs
             op.backward_refs = defaultdict(list)
             for k, v in six.iteritems(op.inputs):
                 if k in backward_refs:
@@ -225,9 +225,6 @@ class FC(layers.Layer):
                 act=act,
                 name=name)
 
-    def parameters(self):
-        return [self._w, self._b]
-
     def _build_once(self, input):
         input_shape = input.shape
         param_shape = [
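Review note: dropping the hand-written FC.parameters() override implies the base layers.Layer can now collect parameters generically; how it does so is not shown in this diff. A hedged sketch of one way such a collector could work (an assumption, not the actual layers.Layer implementation):

    # Hypothetical generic collector; the real layers.Layer may differ.
    class Parameter(object):
        pass

    class Layer(object):
        def parameters(self):
            params = []
            for value in vars(self).values():      # scan instance attributes
                if isinstance(value, Parameter):   # assumed parameter type
                    params.append(value)
            return params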
@@ -131,8 +131,7 @@ class TestImperativeMnist(unittest.TestCase):
                 dy_out = avg_loss._numpy()
 
                 if epoch == 0 and batch_id == 0:
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
+                    for param in mnist.parameters():
                         dy_param_init_value[param.name] = param._numpy()
 
                 avg_loss._backward()
@@ -142,8 +141,7 @@ class TestImperativeMnist(unittest.TestCase):
             fluid.default_main_program().global_block()._clear_block()
 
             dy_param_value = {}
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in mnist.parameters():
                 dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
@@ -169,8 +167,7 @@ class TestImperativeMnist(unittest.TestCase):
             # initialize params and fetch them
             static_param_init_value = {}
             static_param_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in mnist.parameters():
                 static_param_name_list.append(param.name)
 
             out = exe.run(fluid.default_startup_program(),
@@ -204,16 +201,12 @@ class TestImperativeMnist(unittest.TestCase):
         self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
 
         for key, value in six.iteritems(static_param_init_value):
-            if not np.allclose(value, dy_param_init_value[key]):
-                print(key, value, dy_param_value[key])
-            # self.assertTrue(np.allclose(value, dy_param_init_value[key]))
+            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
 
         self.assertTrue(np.allclose(static_out, dy_out))
 
         for key, value in six.iteritems(static_param_value):
-            if not np.allclose(value, dy_param_value[key], atol=1e-6):
-                print(key, value, dy_param_value[key])
-            # self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
+            self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
 
 if __name__ == '__main__':
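Review note: the restored assertions replace print-based debugging, and the second one uses atol=1e-5, one order looser than the 1e-6 the old debug check probed. A quick numpy illustration of what the looser bound admits (values near zero, so np.allclose's default rtol term is negligible):

    import numpy as np

    a = np.zeros(2)
    b = a + 5e-6                    # off by 5e-6 elementwise

    assert not np.allclose(a, b, atol=1e-6)  # outside the tighter bound
    assert np.allclose(a, b, atol=1e-5)      # within the restored bound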
@@ -223,8 +223,7 @@ class TestImperativeResnet(unittest.TestCase):
                 batch_size=batch_size)
 
             dy_param_init_value = {}
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 dy_param_init_value[param.name] = param._numpy()
 
             for batch_id, data in enumerate(train_reader()):
@@ -247,16 +246,14 @@ class TestImperativeResnet(unittest.TestCase):
                 dy_out = avg_loss._numpy()
 
                 if batch_id == 0:
-                    for param in fluid.default_main_program().global_block(
-                    ).all_parameters():
+                    for param in resnet.parameters():
                         if param.name not in dy_param_init_value:
                             dy_param_init_value[param.name] = param._numpy()
 
                 avg_loss._backward()
 
                 dy_grad_value = {}
-                for param in fluid.default_main_program().global_block(
-                ).all_parameters():
+                for param in resnet.parameters():
                     if not param.stop_gradient:
                         np_array = np.array(param._ivar._grad_ivar().value()
                                             .get_tensor())
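Review note: the gradient read stays on the `param._ivar._grad_ivar()` chain, i.e. gradients are reached through the parameter's own C++-side var rather than through the block, which is what keeps them reachable after _clear_block(). A toy model (not PaddlePaddle internals) of why the stop_gradient guard is needed:

    import numpy as np

    class IVar(object):
        def __init__(self, shape, has_grad):
            # stands in for _grad_ivar(): frozen params own no grad buffer
            self.grad = np.zeros(shape) if has_grad else None

    class Param(object):
        def __init__(self, name, shape, stop_gradient=False):
            self.name = name
            self.stop_gradient = stop_gradient
            self._ivar = IVar(shape, has_grad=not stop_gradient)

    params = [Param("w", (2, 2)), Param("frozen", (2,), stop_gradient=True)]
    grads = {p.name: p._ivar.grad for p in params if not p.stop_gradient}
    assert list(grads) == ["w"]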
@@ -269,8 +266,7 @@ class TestImperativeResnet(unittest.TestCase):
             fluid.default_main_program().global_block()._clear_block()
 
             dy_param_value = {}
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
@@ -302,11 +298,9 @@ class TestImperativeResnet(unittest.TestCase):
             static_param_init_value = {}
             static_param_name_list = []
             static_grad_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 static_param_name_list.append(param.name)
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
+            for param in resnet.parameters():
                 if not param.stop_gradient:
                     static_grad_name_list.append(param.name +
                                                  core.grad_var_suffix())
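Review note: on the static-graph side the test still fetches gradients by name; core.grad_var_suffix() returns the framework's gradient-variable suffix ("@GRAD" in Paddle), so the fetch list can be derived from the same names that resnet.parameters() yields:

    GRAD_SUFFIX = "@GRAD"   # what core.grad_var_suffix() returns in Paddle

    param_names = ["conv1.w", "fc.w"]              # hypothetical names
    grad_names = [n + GRAD_SUFFIX for n in param_names]
    assert grad_names == ["conv1.w@GRAD", "fc.w@GRAD"]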