Commit 3ce2d295 authored by minqiyang

Refine stop_gradient

test=develop
Parent c8965dc1
......
@@ -1307,6 +1307,17 @@ class Block(object):
             outputs=kwargs.get("outputs", None),
             attrs=kwargs.get("attrs", None))
         self.ops.append(op)
+        # set stop_gradient in static mode
+        if kwargs.get("stop_gradient", False):
+            outputs = kwargs.get("outputs", None)
+            if outputs is not None:
+                for k, v in six.iteritems(outputs):
+                    if isinstance(v, Variable):
+                        v.stop_gradient = True
+                    elif isinstance(v, list) or isinstance(v, tuple):
+                        for var in v:
+                            var.stop_gradient = True
+        self._trace_op(op, kwargs.get("stop_gradient", False))
         return op
......
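The block added above propagates a stop_gradient=True keyword from append_op to every output Variable of the op, covering both a single Variable and lists/tuples of them. A minimal standalone sketch of that marking logic follows; Variable and mark_outputs_stopped here are simplified stand-ins for illustration (plain dict.items() replaces six.iteritems), not PaddlePaddle's real classes.

class Variable(object):
    """Stand-in for a framework variable carrying a stop_gradient flag."""

    def __init__(self, name):
        self.name = name
        self.stop_gradient = False


def mark_outputs_stopped(outputs):
    # Mirrors the loop in Block.append_op: each value may be a single
    # Variable or a list/tuple of Variables.
    for k, v in outputs.items():
        if isinstance(v, Variable):
            v.stop_gradient = True
        elif isinstance(v, (list, tuple)):
            for var in v:
                var.stop_gradient = True


out = Variable("fc_0.tmp_0")
tmp = [Variable("fc_0.tmp_1"), Variable("fc_0.tmp_2")]
mark_outputs_stopped({"Out": out, "Tmp": tmp})
assert out.stop_gradient and all(v.stop_gradient for v in tmp)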
......
@@ -332,21 +332,16 @@ class BatchNorm(layers.Layer):
             shape=param_shape,
             dtype=self._dtype,
             default_initializer=Constant(1.0))
-        # TODO(minqiyang): change stop_gradient sign to trainable to align with static graph
-        # # setting stop_gradient=True to reduce computation
-        # if use_global_stats and self._helper.param_attr.learning_rate == 0.:
-        #     self._scale.stop_gradient = True
+        if use_global_stats and self._helper.param_attr.learning_rate == 0.:
+            self._scale.stop_gradient = True

         self._bias = self._helper.create_parameter(
             attr=self._helper.bias_attr,
             shape=param_shape,
             dtype=self._dtype,
             is_bias=True)
-        # TODO(minqiyang): change stop_gradient sign to trainable to align with static graph
-        # # setting stop_gradient=True to reduce computation
-        # if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
-        #     self._bias.stop_gradient = True
+        if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
+            self._bias.stop_gradient = True

         self._mean = self._helper.create_parameter(
             attr=ParamAttr(
......
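The re-enabled checks above freeze BatchNorm's scale and bias when use_global_stats is set and the corresponding learning rate is zero: such parameters can never be updated, so skipping their gradients saves computation (the removed TODO notes this flag should eventually become trainable to align with the static graph). A hedged sketch of the rule, using an illustrative helper name rather than Paddle API:

def should_freeze(use_global_stats, learning_rate):
    # A parameter with a zero learning rate under global statistics can
    # never change, so computing its gradient is wasted work.
    return use_global_stats and learning_rate == 0.

assert should_freeze(True, 0.) is True     # frozen: set stop_gradient=True
assert should_freeze(True, 0.1) is False   # still being trained
assert should_freeze(False, 0.) is False   # condition requires both flags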
......
@@ -387,7 +387,7 @@ class Optimizer(object):
         params_grads = []
         for param in parameters:
-            if param.stop_gradient:
+            if param.stop_gradient or not param.trainable:
                 continue
             # create gradient variable
             grad_var = Variable(
......
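The widened condition above makes the optimizer skip a parameter when either its stop_gradient flag is set or its trainable flag is cleared, so both freezing mechanisms suppress gradient creation. A standalone sketch of the filter; Param and grad_params are illustrative stand-ins, not Paddle's Parameter API:

class Param(object):
    def __init__(self, name, stop_gradient=False, trainable=True):
        self.name = name
        self.stop_gradient = stop_gradient
        self.trainable = trainable


def grad_params(parameters):
    # Mirrors the loop above: skip parameters frozen via either flag.
    return [p for p in parameters
            if not (p.stop_gradient or not p.trainable)]


params = [
    Param("fc_w"),                          # gets a gradient
    Param("bn_scale", stop_gradient=True),  # skipped: stop_gradient
    Param("bn_bias", trainable=False),      # skipped: not trainable
]
assert [p.name for p in grad_params(params)] == ["fc_w"]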
......
@@ -98,7 +98,7 @@ class MNIST(fluid.imperative.Layer):
 class TestImperativeMnist(unittest.TestCase):
-    def test_mnist_cpu_float32(self):
+    def test_mnist_float32(self):
         seed = 90
         with fluid.imperative.guard():
......
@@ -196,11 +196,10 @@ class TestImperativeMnist(unittest.TestCase):
             static_param_value[static_param_name_list[i - 1]] = out[i]

         for key, value in six.iteritems(static_param_init_value):
-            self.assertTrue(
-                np.allclose(value.all(), dy_param_init_value[key].all()))
-        self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
+            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
+        self.assertTrue(np.allclose(static_out, dy_out))

         for key, value in six.iteritems(static_param_value):
-            self.assertTrue(np.allclose(value.all(), dy_param_value[key].all()))
+            self.assertTrue(np.allclose(value, dy_param_value[key]))

 if __name__ == '__main__':
......
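The assertion fix above is worth spelling out: ndarray.all() collapses an array to a single boolean, so the old np.allclose(value.all(), dy_param_init_value[key].all()) compared two scalars and passed whenever both arrays happened to be all-nonzero, regardless of their contents. Dropping .all() restores the intended element-wise comparison, as this small demonstration shows:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([9.0, 9.0, 9.0])

# Both arrays are all-nonzero, so .all() reduces each to True and the
# "comparison" degenerates to np.allclose(True, True).
assert np.allclose(a.all(), b.all())   # passes despite a != b

# Element-wise comparison catches the mismatch, as the fixed test intends.
assert not np.allclose(a, b)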