未验证 提交 66dc8e30 编写于 作者: Z Zhou Wei 提交者: GitHub

move the initialize position of grad_clip to optimizer(__init__),and speed up clip (#23782)

上级 361c6ccc
......@@ -137,9 +137,6 @@ class GradientClipBase(object):
raise NotImplementedError
def __call__(self, params_grads):
assert len(
params_grads
) > 0, "The number of trainable parameters should be greater than 0."
if framework.in_dygraph_mode():
return self._dygraph_clip(params_grads)
else:
......@@ -147,7 +144,7 @@ class GradientClipBase(object):
if getattr(p, 'gradient_clip_attr', None) is not None:
warnings.warn(
"'set_gradient_clip' will be ineffective, because you have "
"pass 'grad_clip' into 'minimize'. So, 'set_gradient_clip' "
"set 'grad_clip' in 'optimizer'. So, 'set_gradient_clip' "
"is redundant and you can remove it.")
break
return self._static_clip(params_grads)
......@@ -170,7 +167,7 @@ class GradientClipByValue(GradientClipBase):
The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
is not None, then only part of gradients can be selected for gradient clipping.
Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer``
Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
Args:
......@@ -208,8 +205,8 @@ class GradientClipByValue(GradientClipBase):
# return Parameter.name=="fc_0.w_0"
# clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1)
sgd_optimizer.minimize(loss, grad_clip=clip)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
sgd_optimizer.minimize(loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -242,8 +239,8 @@ class GradientClipByValue(GradientClipBase):
# clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, parameter_list=linear.parameters())
sgd_optimizer.minimize(loss, grad_clip=clip)
learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
sgd_optimizer.minimize(loss)
"""
def __init__(self, max, min=None, need_clip=None):
......@@ -272,6 +269,7 @@ class GradientClipByValue(GradientClipBase):
def _static_clip(self, params_grads):
params_and_grads = []
param_new_grad_name_dict = dict()
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
......@@ -284,7 +282,8 @@ class GradientClipByValue(GradientClipBase):
with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads)
param_new_grad_name_dict[p.name] = new_grad.name
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
......@@ -306,7 +305,7 @@ class GradientClipByNorm(GradientClipBase):
The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
is not None, then only part of gradients can be selected for gradient clipping.
Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer``
Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
The clipping formula is:
......@@ -359,8 +358,8 @@ class GradientClipByNorm(GradientClipBase):
# return Parameter.name=="fc_0.w_0"
# clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1)
sgd_optimizer.minimize(loss, grad_clip=clip)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
sgd_optimizer.minimize(loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -394,8 +393,8 @@ class GradientClipByNorm(GradientClipBase):
# clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, parameter_list=linear.parameters())
sgd_optimizer.minimize(loss, grad_clip=clip)
learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
sgd_optimizer.minimize(loss)
"""
......@@ -422,6 +421,7 @@ class GradientClipByNorm(GradientClipBase):
def _static_clip(self, params_grads):
params_and_grads = []
with framework.name_scope('gradient_clip'):
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
continue
......@@ -432,8 +432,9 @@ class GradientClipByNorm(GradientClipBase):
with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads)
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
......@@ -456,7 +457,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
is not None, then only part of gradients can be selected for gradient clipping.
Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer``
Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
The clipping formula is:
......@@ -505,8 +506,8 @@ class GradientClipByGlobalNorm(GradientClipBase):
# return Parameter.name=="fc_0.w_0"
# clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1)
sgd_optimizer.minimize(loss, grad_clip=clip)
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
sgd_optimizer.minimize(loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -539,8 +540,8 @@ class GradientClipByGlobalNorm(GradientClipBase):
# clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, parameter_list=linear.parameters())
sgd_optimizer.minimize(loss, grad_clip=clip)
learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
sgd_optimizer.minimize(loss)
"""
......@@ -628,6 +629,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
y=layers.elementwise_max(
x=max_global_norm, y=global_norm_var))
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
continue
......@@ -638,9 +640,10 @@ class GradientClipByGlobalNorm(GradientClipBase):
with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads)
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
......@@ -692,9 +695,10 @@ def set_gradient_clip(clip, param_list=None, program=None):
This API must be used after building network, and before ``minimize`` ,
and it may be removed in future releases, so it is not recommended.
It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient.
There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` ,
:ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` .
It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
this is a better method to clip gradient. There are three clipping strategies:
:ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` .
To specify parameters that require gradient clip.
......@@ -757,7 +761,7 @@ def set_gradient_clip(clip, param_list=None, program=None):
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 4: use 'set_gradient_clip' and 'minimize(grad_clip=clip)' together
# network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0)
......@@ -765,8 +769,8 @@ def set_gradient_clip(clip, param_list=None, program=None):
# Set the gradient clipping strategy: clip1
fluid.clip.set_gradient_clip(clip1)
# Set the gradient clipping strategy: clip2
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss, grad_clip=clip2)
sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
sgd.minimize(loss)
# 'set_gradient_clip' will not take effect when setting has a conflict,
# and the gradient clipping strategy will be 'clip2'
......@@ -774,10 +778,10 @@ def set_gradient_clip(clip, param_list=None, program=None):
"""
warnings.warn("Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! "
"We recommend a new strategy: clip gradient by "
"'optimizer.minimize(loss, grad_clip=clip)'. "
"We recommend a new strategy: set 'grad_clip' "
"when initializing the 'optimizer'. "
"This method can reduce the mistakes, please "
"see documention of 'optimzier.minimize'.")
"refer to documention of 'optimizer'.")
if not isinstance(clip, GradientClipBase):
raise TypeError(
......@@ -824,33 +828,40 @@ def append_gradient_clip_ops(param_grads):
clip_attr._process_context(context=context, param=p, grad=g)
res = []
param_new_grad_name_dict = dict()
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('graident_clip_@CLIP'):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
_correct_clip_op_role_var(res)
_correct_clip_op_role_var(res, param_new_grad_name_dict)
return res
# change wrong mapping relation between param & grad in clip op
def _correct_clip_op_role_var(params_grads):
def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
block_id_list = []
if len(param_new_grad_name_dict) == 0:
return
for param, grad in params_grads:
if grad is None:
continue
block_id = param.block.idx
if block_id in block_id_list:
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops:
if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
"op_namescope"):
if op.attr('op_role_var'):
param_name = op.attr('op_role_var')[0]
index = 0
for i in range(len(params_grads)):
if params_grads[i][0].name == param_name:
index = i
correct_p_g = [param_name, params_grads[index][1].name]
"op_namescope") and op.attr('op_role_var'):
param_name = op.attr('op_role_var')[0]
if param_name in param_new_grad_name_dict:
correct_p_g = [
param_name, param_new_grad_name_dict[param_name]
]
op._set_attr('op_role_var', correct_p_g)
......
此差异已折叠。
......@@ -36,7 +36,7 @@ class ParamAttr(object):
Note:
``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0.
It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient.
It is recommended to set ``grad_clip`` in ``optimizer`` to clip gradient.
There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` ,
:ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` .
......
......@@ -19,6 +19,7 @@ import unittest
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
import paddle.fluid.clip as clip
import paddle.compat as cpt
from paddle.fluid.backward import append_backward
from paddle.fluid.transpiler.details import program_to_code
......@@ -70,9 +71,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
learning_rate=learning_rate,
momentum=0.2,
rampup_begin_step=0,
local_grad_clip_norm=1.0,
num_trainers=2,
regularization=regularization)
regularization=regularization,
grad_clip=clip.GradientClipByNorm(1.0))
if use_recompute:
dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
......@@ -124,6 +125,16 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
#with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f:
# program_to_code(program, fout=f)
def test_tpyeError(self):
# the type of DGCMomentumOptimizer(grad_clip=) must be 'GradientClipByNorm'
with self.assertRaises(TypeError):
dgc_momentum_optimizer = self.MockDGCMomentum(
learning_rate=0.01,
momentum=0.2,
rampup_begin_step=0,
num_trainers=2,
grad_clip=clip.GradientClipByGlobalNorm(1.0))
def test_momentum_without_dgc(self):
self.check_dgc_momentum_optimizer(
regularization=regularizer.L1Decay(1e-4))
......
......@@ -76,8 +76,8 @@ class TestGradientClip(unittest.TestCase):
startup_program = fluid.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
image = fluid.data(name='x', shape=[-1, 784], dtype='float32')
label = fluid.data(name='y', shape=[-1, 1], dtype='int64')
image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
hidden = fluid.layers.fc(input=image, size=32, act='relu')
predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
......@@ -112,13 +112,13 @@ class TestGradientClip(unittest.TestCase):
self.check_clip_result(out, out_clip)
def check_sparse_gradient_clip(self, place):
prog = fluid.framework.Program()
startup_program = fluid.framework.Program()
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data = fluid.data(
name="words", shape=[-1, 1], dtype="int64", lod_level=1)
label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
cost = bow_net(data, label, self.word_dict_len)
self.backward_and_optimize(cost)
......@@ -172,7 +172,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace())
# test whether the ouput is right when use 'minimize(grad_clip)'
# test whether the ouput is right when use grad_clip
def test_new_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
......@@ -192,9 +192,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=5.0, need_clip=fileter_func)
fluid.clip.set_gradient_clip(clip)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
# if 'set_gradient_clip' and 'minimize(grad_clip)' together, 'set_gradient_clip' will be ineffective
sgd_optimizer.minimize(cost, grad_clip=clip)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01,
grad_clip=clip)
# if 'set_gradient_clip' and 'optimize(grad_clip)' together, 'set_gradient_clip' will be ineffective
sgd_optimizer.minimize(cost)
# 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective
fluid.clip.set_gradient_clip(clip)
......@@ -232,24 +233,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm, need_clip="test")
# the type of minimize(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError):
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
loss = fluid.layers.reduce_mean(x)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(loss, grad_clip="test")
# the type of RecomputeOptimizer.minimize(grad_clip=) must be an instance of GradientClipBase's derived class
# the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError):
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
loss = fluid.layers.reduce_mean(x)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
recompute_optimizer = fluid.optimizer.RecomputeOptimizer(
sgd_optimizer)
recompute_optimizer._set_checkpoints([x])
recompute_optimizer.minimize(loss, grad_clip="test")
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1,
grad_clip="test")
class TestGradientClipByNorm(TestGradientClip):
......@@ -271,7 +258,7 @@ class TestGradientClipByNorm(TestGradientClip):
a=u, b=v, rtol=1e-5, atol=1e-8),
"gradient clip by norm has wrong results!")
# test whether the ouput is right when use 'minimize(grad_clip)'
# test whether the ouput is right when use grad_clip
def test_gradient_clip(self):
self.check_gradient_clip(fluid.CPUPlace())
......@@ -319,7 +306,7 @@ class TestGradientClipByValue(TestGradientClip):
a=u, b=v, rtol=1e-6, atol=1e-8),
"gradient clip by value has wrong results!")
# test whether the ouput is right when use 'minimize(grad_clip)'
# test whether the ouput is right when use grad_clip
def test_gradient_clip(self):
self.check_gradient_clip(fluid.CPUPlace())
......@@ -357,7 +344,9 @@ class TestDygraphGradientClip(unittest.TestCase):
loss = fluid.layers.reduce_mean(out)
loss.backward()
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.0, parameter_list=linear.parameters())
learning_rate=0.0,
parameter_list=linear.parameters(),
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
self.check_clip_result(loss, sgd_optimizer)
def check_clip_result(self, loss, optimizer):
......@@ -384,7 +373,7 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
np.array([3, 4]).astype("float32"), name="y")
assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2
# get params and grads from network
opt, params_grads = optimizer.minimize(loss, grad_clip=self.clip2)
opt, params_grads = optimizer.minimize(loss)
_, grads = zip(*params_grads)
params_grads = self.clip2(params_grads)
_, grads_clip = zip(*params_grads)
......@@ -426,7 +415,7 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
assert len(self.clip([(x, None)])) == 0
# get params and grads from network
self.clip([(fluid.dygraph.to_variable(np.array([2, 3])), None)])
params_grads = optimizer.backward(loss)
opt, params_grads = optimizer.minimize(loss)
_, grads = zip(*params_grads)
params_grads = self.clip(params_grads)
_, grads_clip = zip(*params_grads)
......@@ -460,7 +449,7 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip):
x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"))
assert len(self.clip([(x, None)])) == 0
# get params and grads from network
params_grads = optimizer.backward(loss)
opt, params_grads = optimizer.minimize(loss)
_, grads = zip(*params_grads)
params_grads = self.clip(params_grads)
_, grads_clip = zip(*params_grads)
......
......@@ -329,9 +329,9 @@ class TestImperativeAutoPrune(unittest.TestCase):
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = MyLayer(size, vocab_size, size)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters())
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters(), grad_clip=grad_clip)
indices = fluid.dygraph.to_variable(indices)
embed = fluid.dygraph.to_variable(embed)
......@@ -339,7 +339,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
loss = model.embed_linear0(indices)
loss.backward()
_, params_grads = optimizer.minimize(loss, grad_clip=grad_clip)
_, params_grads = optimizer.minimize(loss)
for items in params_grads:
assert items[0].name is not model.embed1.weight.name
assert items[0].name is not model.linear_1.weight.name
......@@ -348,9 +348,9 @@ class TestImperativeAutoPrune(unittest.TestCase):
with fluid.dygraph.guard(place):
model = MyLayer2(size, vocab_size, size)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters())
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters(), grad_clip=grad_clip)
indices = fluid.dygraph.to_variable(indices)
emebd = fluid.dygraph.to_variable(embed)
......@@ -358,7 +358,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
loss = model.embed_linear0(indices)
loss.backward()
optimizer.minimize(loss, grad_clip=grad_clip)
optimizer.minimize(loss)
for items in params_grads:
assert items[0].name is not model.embed1.weight.name
assert items[0].name is not model.linear_1.weight.name
......
......@@ -58,14 +58,15 @@ class TestSimpleNet(unittest.TestCase):
simplenet = SimpleNet(20, 32, dtype)
adam = SGDOptimizer(
learning_rate=0.001,
parameter_list=simplenet.parameters())
parameter_list=simplenet.parameters(
)) # grad_clip=grad_clip
input_emb, emb = simplenet(input)
self.assertTrue(emb.weight.gradient() is None)
self.assertTrue(input_emb.gradient() is None)
input_emb.backward(backward_strategy)
adam.minimize(input_emb) # grad_clip=grad_clip
adam.minimize(input_emb)
self.assertTrue(emb.weight.gradient() is not None)
emb.clear_gradients()
......@@ -92,14 +93,15 @@ class TestSimpleNet(unittest.TestCase):
simplenet = SimpleNet(20, 32, "float32")
adam = SGDOptimizer(
learning_rate=0.001,
parameter_list=simplenet.parameters())
parameter_list=simplenet.parameters(),
grad_clip=grad_clip)
input_emb, emb = simplenet(input)
self.assertTrue(emb.weight.gradient() is None)
self.assertTrue(input_emb.gradient() is None)
input_emb.backward(backward_strategy)
adam.minimize(input_emb, grad_clip=grad_clip)
adam.minimize(input_emb)
self.assertTrue(emb.weight.gradient() is not None)
emb.clear_gradients()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册