Unverified commit e7c67e11, authored by Yu Yang, committed by GitHub

Add stop_gradient in Variable (#5361)

Parent 2be4c3cb
......@@ -19,8 +19,20 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
:rtype: list[Variable]
"""
assert isinstance(loss, framework.Variable)
param_grad_map = loss.block.program.append_backward(loss, no_grad_set or
set())
if no_grad_set is None:
program = loss.block.program
assert isinstance(program, framework.Program)
no_grad_set = list()
for block in program.blocks:
assert isinstance(block, framework.Block)
for var in block.vars.itervalues():
assert isinstance(var, framework.Variable)
if var.stop_gradient:
no_grad_set.append(var.name)
no_grad_set = set(no_grad_set)
param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
if parameter_list is not None:
parameters = parameter_list
else:
......
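For context, the hunk above changes append_backward_ops so that, when no explicit no_grad_set is passed, it walks every block of the program and collects the names of all variables flagged stop_gradient=True before appending the backward pass. A minimal, self-contained sketch of that mechanism, using toy stand-ins rather than the real framework.Variable / framework.Block classes (all names here are illustrative assumptions, not part of the diff):

# Toy stand-ins for framework.Variable / framework.Block (illustrative only).
class Var(object):
    def __init__(self, name, stop_gradient=False):
        self.name = name
        self.stop_gradient = stop_gradient

class Block(object):
    def __init__(self, variables):
        self.vars = {v.name: v for v in variables}

def collect_no_grad_set(blocks):
    # Mirrors the new default path in append_backward_ops: gather the names
    # of all variables whose stop_gradient flag is set.
    no_grad_set = set()
    for block in blocks:
        for var in block.vars.values():
            if var.stop_gradient:
                no_grad_set.add(var.name)
    return no_grad_set

block = Block([Var("image", stop_gradient=True), Var("fc_w"), Var("fc_b")])
print(collect_no_grad_set([block]))  # only "image" is excluded from backward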
......@@ -21,6 +21,7 @@ class Variable(object):
dtype=None,
lod_level=None,
persistable=None,
stop_gradient=False,
**kwargs):
self.block = block
......@@ -89,6 +90,7 @@ class Variable(object):
self.block.vars[name] = self
self.op = None
self.stop_gradient = stop_gradient
def __str__(self):
protostr = self.desc.serialize_to_string()
......
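The Variable hunk above adds a stop_gradient keyword argument, defaulting to False, that is simply stored on the instance. A hedged, toy-level sketch of the resulting behavior (not the real framework.Variable):

class Variable(object):
    def __init__(self, name, stop_gradient=False, **kwargs):
        self.name = name
        self.stop_gradient = stop_gradient

v = Variable("h_boot")
assert v.stop_gradient is False   # default: gradients flow through this var
v.stop_gradient = True            # opt the variable out of backward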
......@@ -99,7 +99,7 @@ def data(name,
shape = [-1] + shape # append batch size as -1
return helper.create_global_variable(
name=name, shape=shape, dtype=data_type, type=type)
name=name, shape=shape, dtype=data_type, type=type, stop_gradient=True)
def _convert_(name):
......
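The layers.py hunk makes data() create its output variable with stop_gradient=True, so feed inputs are skipped by backward unless the caller re-enables gradients, which is exactly what the recurrent-op tests below do. A rough, self-contained sketch of that default (make_data_var is a hypothetical helper standing in for the real data() layer):

class Var(object):
    def __init__(self, name, stop_gradient=False):
        self.name = name
        self.stop_gradient = stop_gradient

def make_data_var(name):
    # mirrors the new default in data(): inputs do not receive gradients
    return Var(name, stop_gradient=True)

x = make_data_var("x")
assert x.stop_gradient            # excluded from backward by default
x.stop_gradient = False           # tests flip it back to check input gradients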
......@@ -125,11 +125,13 @@ class RecurrentOpTest1(unittest.TestCase):
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
h_boot = data(
shape=[self.input_dim],
data_type='float32',
name='h_boot',
**self.p_info)
h_boot.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program)
with rnn.step():
......@@ -256,11 +258,13 @@ class RecurrentOpTest2(RecurrentOpTest1):
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
h_boot = data(
shape=[self.input_dim],
data_type='float32',
name='h_boot',
**self.p_info)
h_boot.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program)
with rnn.step():
......@@ -353,18 +357,21 @@ class RecurrentOpTest3(RecurrentOpTest1):
name='x',
append_batch_size=False,
**self.p_info)
x.stop_gradient = False
h_boot1 = data(
shape=[self.batch_size, self.input_dim],
data_type='float32',
name='h_boot1',
append_batch_size=False,
**self.p_info)
h_boot1.stop_gradient = False
h_boot2 = data(
shape=[self.batch_size, self.input_dim],
data_type='float32',
name='h_boot2',
append_batch_size=False,
**self.p_info)
h_boot2.stop_gradient = False
rnn = StaticRNN(main_program=self.main_program)
with rnn.step():
......