未验证 提交 e7c67e11 编写于 作者: Y Yu Yang 提交者: GitHub

Add stop_gradient in Variable (#5361)

上级 2be4c3cb
...@@ -19,8 +19,20 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None): ...@@ -19,8 +19,20 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
:rtype: list[Variable] :rtype: list[Variable]
""" """
assert isinstance(loss, framework.Variable) assert isinstance(loss, framework.Variable)
param_grad_map = loss.block.program.append_backward(loss, no_grad_set or
set()) if no_grad_set is None:
program = loss.block.program
assert isinstance(program, framework.Program)
no_grad_set = list()
for block in program.blocks:
assert isinstance(block, framework.Block)
for var in block.vars.itervalues():
assert isinstance(var, framework.Variable)
if var.stop_gradient:
no_grad_set.append(var.name)
no_grad_set = set(no_grad_set)
param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
if parameter_list is not None: if parameter_list is not None:
parameters = parameter_list parameters = parameter_list
else: else:
......
...@@ -21,6 +21,7 @@ class Variable(object): ...@@ -21,6 +21,7 @@ class Variable(object):
dtype=None, dtype=None,
lod_level=None, lod_level=None,
persistable=None, persistable=None,
stop_gradient=False,
**kwargs): **kwargs):
self.block = block self.block = block
...@@ -89,6 +90,7 @@ class Variable(object): ...@@ -89,6 +90,7 @@ class Variable(object):
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
self.stop_gradient = stop_gradient
def __str__(self): def __str__(self):
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
......
...@@ -99,7 +99,7 @@ def data(name, ...@@ -99,7 +99,7 @@ def data(name,
shape = [-1] + shape # append batch size as -1 shape = [-1] + shape # append batch size as -1
return helper.create_global_variable( return helper.create_global_variable(
name=name, shape=shape, dtype=data_type, type=type) name=name, shape=shape, dtype=data_type, type=type, stop_gradient=True)
def _convert_(name): def _convert_(name):
......
...@@ -125,11 +125,13 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -125,11 +125,13 @@ class RecurrentOpTest1(unittest.TestCase):
name='x', name='x',
append_batch_size=False, append_batch_size=False,
**self.p_info) **self.p_info)
x.stop_gradient = False
h_boot = data( h_boot = data(
shape=[self.input_dim], shape=[self.input_dim],
data_type='float32', data_type='float32',
name='h_boot', name='h_boot',
**self.p_info) **self.p_info)
h_boot.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program) rnn = StaticRNN(main_program=self.main_program)
with rnn.step(): with rnn.step():
...@@ -256,11 +258,13 @@ class RecurrentOpTest2(RecurrentOpTest1): ...@@ -256,11 +258,13 @@ class RecurrentOpTest2(RecurrentOpTest1):
name='x', name='x',
append_batch_size=False, append_batch_size=False,
**self.p_info) **self.p_info)
x.stop_gradient = False
h_boot = data( h_boot = data(
shape=[self.input_dim], shape=[self.input_dim],
data_type='float32', data_type='float32',
name='h_boot', name='h_boot',
**self.p_info) **self.p_info)
h_boot.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program) rnn = StaticRNN(main_program=self.main_program)
with rnn.step(): with rnn.step():
...@@ -353,18 +357,21 @@ class RecurrentOpTest3(RecurrentOpTest1): ...@@ -353,18 +357,21 @@ class RecurrentOpTest3(RecurrentOpTest1):
name='x', name='x',
append_batch_size=False, append_batch_size=False,
**self.p_info) **self.p_info)
x.stop_gradient = False
h_boot1 = data( h_boot1 = data(
shape=[self.batch_size, self.input_dim], shape=[self.batch_size, self.input_dim],
data_type='float32', data_type='float32',
name='h_boot1', name='h_boot1',
append_batch_size=False, append_batch_size=False,
**self.p_info) **self.p_info)
h_boot1.stop_gradient = False
h_boot2 = data( h_boot2 = data(
shape=[self.batch_size, self.input_dim], shape=[self.batch_size, self.input_dim],
data_type='float32', data_type='float32',
name='h_boot2', name='h_boot2',
append_batch_size=False, append_batch_size=False,
**self.p_info) **self.p_info)
h_boot2.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program) rnn = StaticRNN(main_program=self.main_program)
with rnn.step(): with rnn.step():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册