Unverified commit 68497e7b, authored by Chen Weihang, committed by GitHub

change trainable to stop_gradient in optimizer (#31823)

Parent 270699e6
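
Background for the hunks below: in Paddle's dygraph mode a parameter's trainable flag and its stop_gradient flag are intended to be logical opposites, so every `trainable` check in the optimizer can be rewritten as the negated `stop_gradient` check. A minimal sketch of that assumed equivalence (the Linear layer and variable names are illustrative, not part of this commit):

import paddle

# Hedged sketch: assumes dygraph parameters expose both `trainable` and
# `stop_gradient`, with `trainable == (not stop_gradient)`.
linear = paddle.nn.Linear(4, 4)
w = linear.weight

# `not param.trainable` can therefore be replaced by `param.stop_gradient`.
assert w.trainable == (not w.stop_gradient)

w.stop_gradient = True        # freeze the parameter
assert w.trainable is False   # both flags reflect the same state
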
@@ -351,7 +351,7 @@ class Adam(Optimizer):
         """
         params_grads = []
         for param in self._parameter_list:
-            if not param.trainable:
+            if param.stop_gradient:
                 continue
             if param._grad_ivar() is not None:
                 grad_var = param._grad_ivar()
@@ -184,7 +184,7 @@ class Adamax(Optimizer):
         """
         assert isinstance(block, framework.Block)
         for param, grad in parameters_and_grads:
-            if grad is None or param.trainable is False:
+            if grad is None or param.stop_gradient is True:
                 continue
             with param.block.program._optimized_guard(
                     [param, grad]), name_scope('adamax'):
@@ -542,7 +542,7 @@ class Optimizer(object):
     def _update_param_device_map(self, parameters_and_grads, target_block):
         for param_and_grad in parameters_and_grads:
-            if param_and_grad[0].trainable is True:
+            if param_and_grad[0].stop_gradient is False:
                 param_name = param_and_grad[0].name
                 ops = target_block.ops
                 device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName(
@@ -598,14 +598,14 @@ class Optimizer(object):
         self._update_param_device_map(parameters_and_grads, target_block)
         self._create_accumulators(
             target_block,
-            [p[0] for p in parameters_and_grads if p[0].trainable])
+            [p[0] for p in parameters_and_grads if not p[0].stop_gradient])
         self._create_global_learning_rate()
         if framework.in_dygraph_mode():
             for param_and_grad in parameters_and_grads:
                 if param_and_grad[1] is None:
                     continue
-                if param_and_grad[0].trainable is True:
+                if param_and_grad[0].stop_gradient is False:
                     self._append_optimize_op(target_block, param_and_grad)
         else:
             for param_and_grad in parameters_and_grads:
@@ -613,7 +613,7 @@ class Optimizer(object):
                     continue
                 with param_and_grad[0].block.program._optimized_guard(
                         param_and_grad), name_scope("optimizer"):
-                    if param_and_grad[0].trainable is True:
+                    if param_and_grad[0].stop_gradient is False:
                         device = self._get_device_for_param(param_and_grad[0]
                                                             .name)
                         with device_guard(device):
@@ -689,7 +689,7 @@ class Optimizer(object):
             params_grads = []
             for param in parameter_list:
-                if not param.trainable:
+                if param.stop_gradient:
                     continue
                 if param._grad_ivar() is not None:
                     # create gradient tensor
@@ -789,8 +789,9 @@ class Optimizer(object):
     def _get_no_grad_set(self, loss, no_grad_set=None):
         no_grad_set = _get_no_grad_set_name(no_grad_set)
         parameters = loss.block.program.global_block().all_parameters()
-        param_no_trainable = set(
-            [param.name for param in parameters if param.trainable is False])
+        param_no_trainable = set([
+            param.name for param in parameters if param.stop_gradient is True
+        ])
         # If the parameter is no trainable, it should not have a gradient.
         no_grad_set.update(param_no_trainable)
@@ -825,7 +826,7 @@ class Optimizer(object):
         """
         for p in self._parameter_list:
-            if p.trainable:
+            if not p.stop_gradient:
                 p.clear_gradient()
     @imperative_base.no_grad
@@ -920,7 +921,7 @@ class Optimizer(object):
         """
         params_grads = []
         for param in self._parameter_list:
-            if not param.trainable:
+            if param.stop_gradient:
                 continue
             if param._grad_ivar() is not None:
                 grad_var = param._grad_ivar()
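
For context, a usage sketch of the behavior these hunks preserve: a parameter frozen with `stop_gradient = True` is skipped by the optimizer loops above, now keyed off `stop_gradient` rather than `trainable`. The layer sizes and data below are illustrative only and not taken from this commit:

import paddle

# Freeze one layer, train the other; in practice you might pass only the
# trainable parameters to the optimizer, but the loops above also skip
# frozen ones that are passed in.
frozen = paddle.nn.Linear(4, 4)
trained = paddle.nn.Linear(4, 4)
for p in frozen.parameters():
    p.stop_gradient = True    # these parameters are skipped by the optimizer

opt = paddle.optimizer.Adam(
    parameters=list(frozen.parameters()) + list(trained.parameters()))

x = paddle.randn([2, 4])
loss = paddle.mean(trained(frozen(x)))
loss.backward()
opt.step()          # only parameters with stop_gradient == False are updated
opt.clear_grad()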