Unverified commit 62fd3209 authored by W WangXi, committed by GitHub

Fix dgc param regularizer, test=develop (#22888)

Parent 07e13b84
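In short, this change makes DGCMomentumOptimizer honor a regularizer attached directly to a Parameter: the (regular_type, regular_coeff) attributes passed to each dgc op are now resolved per parameter, and a parameter-level regularizer takes priority over the optimizer-level regularization argument. The unit test is extended to cover both configurations.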
@@ -1149,21 +1149,23 @@ class DGCMomentumOptimizer(Optimizer):
         self._num_trainers = num_trainers
         self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5)

-        self._get_dgc_regularization_param()
+        self.regular_type, self.regular_coeff = self._get_regularization_param(
+            self.regularization)

-    def _get_dgc_regularization_param(self):
-        self.regular_coeff = 0.0
-        self.regular_type = 0
+    def _get_regularization_param(self, regularization):
+        regular_type = 0
+        regular_coeff = 0.0

-        if self.regularization is not None:
-            self.regular_coeff = self.regularization._regularization_coeff
+        if regularization is not None:
+            regular_coeff = regularization._regularization_coeff
             from .regularizer import L1Decay, L2Decay
-            if isinstance(self.regularization, L1Decay):
-                self.regular_type = 1
-            elif isinstance(self.regularization, L2Decay):
-                self.regular_type = 2
+            if isinstance(regularization, L1Decay):
+                regular_type = 1
+            elif isinstance(regularization, L2Decay):
+                regular_type = 2
             else:
                 assert False, 'regularization must be None|L1Decay|L2Decay'
+        return regular_type, regular_coeff

     def _is_use_dgc(self, param_var, grad_var):
         var_numel = abs(reduce(lambda x, y: x * y, param_var.shape))
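For reference, here is a minimal, self-contained sketch of what the new _get_regularization_param helper computes. L1Decay and L2Decay below are simplified stand-ins defined only for this sketch, not Paddle's real regularizer classes:

# Sketch only: stand-in regularizer classes, not paddle.fluid.regularizer.
class L1Decay(object):
    def __init__(self, coeff):
        self._regularization_coeff = coeff


class L2Decay(object):
    def __init__(self, coeff):
        self._regularization_coeff = coeff


def get_regularization_param(regularization):
    """Map a regularizer to the (regular_type, regular_coeff) pair used by
    the dgc op: 0 = no regularization, 1 = L1 decay, 2 = L2 decay."""
    regular_type = 0
    regular_coeff = 0.0
    if regularization is not None:
        regular_coeff = regularization._regularization_coeff
        if isinstance(regularization, L1Decay):
            regular_type = 1
        elif isinstance(regularization, L2Decay):
            regular_type = 2
        else:
            raise AssertionError('regularization must be None|L1Decay|L2Decay')
    return regular_type, regular_coeff


assert get_regularization_param(None) == (0, 0.0)
assert get_regularization_param(L1Decay(1e-4)) == (1, 1e-4)
assert get_regularization_param(L2Decay(2e-4)) == (2, 2e-4)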
@@ -1364,6 +1366,13 @@ class DGCMomentumOptimizer(Optimizer):
         block = framework.default_main_program().global_block()
         op_maker = core.op_proto_and_checker_maker

+        regular_type = self.regular_type
+        regular_coeff = self.regular_coeff
+        # The regularizer of the Parameters has higher priority
+        if param_var.regularizer is not None:
+            regular_type, regular_coeff = self._get_regularization_param(
+                param_var.regularizer)
+
         dgc_op = block.append_op(
             type="dgc",
             inputs={
@@ -1388,8 +1397,8 @@ class DGCMomentumOptimizer(Optimizer):
                 "use_nesterov": self._use_nesterov,
                 "rampup_begin_step": float(self._rampup_begin_step),
                 "rampup_step": float(self._rampup_step),
-                "regular_coeff": float(self.regular_coeff),
-                "regular_type": int(self.regular_type),
+                "regular_coeff": float(regular_coeff),
+                "regular_type": int(regular_type),
             },
             stop_gradient=True)
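The snippet below is a hedged illustration of the priority rule applied when the dgc op is appended: a regularizer set on the parameter itself overrides the optimizer-level values resolved in __init__. FakeParam and L2Decay are throwaway stand-ins, not Paddle classes, and resolve_dgc_attrs only sketches the L2 branch:

# Sketch only: simplified stand-ins, not Paddle's internal classes.
class L2Decay(object):
    def __init__(self, coeff):
        self._regularization_coeff = coeff


class FakeParam(object):
    def __init__(self, regularizer=None):
        self.regularizer = regularizer


def resolve_dgc_attrs(param, optimizer_type, optimizer_coeff):
    """Pick the (regular_type, regular_coeff) for one parameter's dgc op.
    The parameter-level regularizer has higher priority than the values
    derived from the optimizer-level regularization."""
    if param.regularizer is not None:
        # Only the L2 case is sketched here; the real helper also
        # distinguishes L1Decay and None.
        return 2, param.regularizer._regularization_coeff
    return optimizer_type, optimizer_coeff


# Optimizer configured with L2 decay 1e-4; one parameter carries L2 decay 2e-4.
print(resolve_dgc_attrs(FakeParam(L2Decay(2e-4)), 2, 1e-4))  # (2, 0.0002)
print(resolve_dgc_attrs(FakeParam(), 2, 1e-4))               # (2, 0.0001)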
@@ -44,7 +44,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
             shape=[dims[0], dims[1]],
             lod_level=0,
             name="mul.x",
-            optimize_attr={'learning_rate': 1.1})
+            optimize_attr={'learning_rate': 1.1},
+            regularizer=None if regularization is not None else
+            regularizer.L2DecayRegularizer(2e-4))
         mul_y = block.create_var(
             dtype="float32",
             shape=[dims[1], dims[2]],
@@ -102,6 +104,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
         self.assertEqual(init_ops[0].type, "fill_constant")
         self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)

+        # check dgc op regularization coeff
+        train_ops = program.global_block().ops
+        for op in train_ops:
+            if op.type == "dgc":
+                coeff = 2e-4 if regularization is None else 1e-4
+                self.assertAlmostEqual(op.attr('regular_coeff'), coeff)
+                print("dgc regular_coeff=" + str(coeff))
+
         with open("test_dgc_optimizer_" + name + ".log", "w") as f:
             program_to_code(program, fout=f)
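The coefficient checked above follows directly from how the test sets things up: mul.x only gets its own L2DecayRegularizer(2e-4) when no optimizer-level regularization is passed, so the dgc op is expected to carry 2e-4 in that case and 1e-4 (the optimizer-level L2Decay) otherwise. A small illustrative helper restating that expectation (not part of the test itself):

# Illustration only: expected regular_coeff for the two test configurations.
def expected_dgc_coeff(optimizer_level_coeff):
    if optimizer_level_coeff is None:
        # No optimizer-level regularization: the test attaches
        # L2DecayRegularizer(2e-4) to the parameter, and that wins.
        return 2e-4
    # Otherwise the parameter has no regularizer of its own, so the
    # optimizer-level coefficient is used.
    return optimizer_level_coeff


assert expected_dgc_coeff(None) == 2e-4
assert expected_dgc_coeff(1e-4) == 1e-4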
@@ -116,6 +126,10 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
             name="dgc_momentum",
             regularization=regularizer.L2Decay(1e-4))

+        # check param.regularizer in dgc
+        self.check_dgc_momentum_optimizer(
+            dims=[16, 1024, 8], name="dgc_momentum")
+

 if __name__ == '__main__':
     unittest.main()