Unverified commit f2265d9f, authored by WangXi, committed by GitHub

Fix problem when using recompute and dgc at the same time (#23010)

Parent 8fb8b952
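This commit makes gradient recomputation (`RecomputeOptimizer`) and deep gradient compression (`DGCMomentumOptimizer`) usable together: the wrapper's `backward` now appends the dgc ops produced by the wrapped optimizer. Below is a minimal usage sketch of the combination under the fluid API of this era; the network, checkpoint choice, and hyperparameters are illustrative assumptions, and DGC itself needs a CUDA/NCCL build with more than one trainer to do anything useful.

```python
import paddle.fluid as fluid

# Illustrative sketch only (static-graph fluid API): the model, sizes and
# optimizer hyperparameters are assumptions, not taken from the commit.
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name="x", shape=[1024], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="int64")
    hidden = fluid.layers.fc(input=x, size=1024, act="relu")
    pred = fluid.layers.fc(input=hidden, size=10, act="softmax")
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=pred, label=y))

    # DGC momentum optimizer wrapped in recompute; before this fix the
    # dgc ops were not appended when backward() went through the wrapper.
    dgc_opt = fluid.optimizer.DGCMomentumOptimizer(
        learning_rate=0.001,
        momentum=0.9,
        rampup_begin_step=0,
        num_trainers=2)
    opt = fluid.optimizer.RecomputeOptimizer(dgc_opt)
    opt._set_checkpoints([hidden])   # recompute activations from here
    opt.minimize(loss)
```

With the fix, `opt.minimize(loss)` produces both the recompute-aware backward and, for parameters large enough for DGC to kick in (the test below uses 16 * 1024 = 16384 elements), the dgc ops at the end of each gradient's op chain.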
@@ -3926,6 +3926,9 @@ class RecomputeOptimizer(Optimizer):
                 parameter_list,
                 no_grad_set,
                 checkpoints=self._checkpoints)
+            # Note: since we can't use all_reduce_op now,
+            # dgc_op should be the last op of one grad.
+            self._optimizer._append_dgc_ops(params_grads)
         return params_grads
 
     def apply_optimize(self, loss, startup_program, params_grads):
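The hunk above is the core of the fix: `RecomputeOptimizer.backward` delegates to the wrapped optimizer's `backward` (passing the recompute checkpoints) and previously returned without letting a wrapped `DGCMomentumOptimizer` insert its compression ops, which is why the two features could not be combined. The sketch below restates that delegation pattern in isolation; the class name and the `hasattr` guard are illustrative, not PaddlePaddle source.

```python
# Simplified illustration of the wrapper pattern in the hunk above; this is
# NOT the real RecomputeOptimizer, just the shape of the delegation.  The
# ordering is the point: the dgc ops are appended only after the inner
# backward has produced params_grads, so the dgc op ends up as the last op
# acting on each gradient.
class RecomputeWrapperSketch(object):
    def __init__(self, inner_optimizer, checkpoints=None):
        self._optimizer = inner_optimizer    # e.g. a DGCMomentumOptimizer
        self._checkpoints = checkpoints      # activations to recompute from

    def backward(self, loss, startup_program=None, parameter_list=None,
                 no_grad_set=None):
        params_grads = self._optimizer.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
            checkpoints=self._checkpoints)
        # Before this commit the wrapper returned here, so a wrapped DGC
        # optimizer never got to insert its compression ops.  The hasattr
        # guard is only a hedge in this sketch; the real code calls the
        # method directly, as shown in the diff.
        if hasattr(self._optimizer, "_append_dgc_ops"):
            self._optimizer._append_dgc_ops(params_grads)
        return params_grads
```

The remaining hunks adapt the DGC optimizer unit test so this wrapped path is exercised.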
@@ -35,7 +35,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
     def check_dgc_momentum_optimizer(self,
                                      dims=[5, 10, 8],
                                      name="momentum",
-                                     regularization=None):
+                                     regularization=None,
+                                     use_recompute=False):
         init_program = framework.Program()
         program = framework.Program()
         block = program.global_block()
@@ -72,6 +73,13 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
             local_grad_clip_norm=1.0,
             num_trainers=2,
             regularization=regularization)
+
+        if use_recompute:
+            dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
+                dgc_momentum_optimizer)
+            dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators
+            dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str
+
         mean_out = block.create_var(
             dtype="float32", shape=[1], lod_level=0, name="mean.out")
         block.append_op(
@@ -112,8 +120,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
                 self.assertAlmostEqual(op.attr('regular_coeff'), coeff)
                 print("dgc regular_coeff=" + str(coeff))
 
-        with open("test_dgc_optimizer_" + name + ".log", "w") as f:
-            program_to_code(program, fout=f)
+        # for local test debug
+        #with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f:
+        #    program_to_code(program, fout=f)
 
     def test_momentum_without_dgc(self):
         self.check_dgc_momentum_optimizer(
@@ -130,6 +139,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
         self.check_dgc_momentum_optimizer(
             dims=[16, 1024, 8], name="dgc_momentum")
 
+    def test_momentum_with_dgc_recompute(self):
+        # 16 * 1024 = 16384, use dgc momentum
+        self.check_dgc_momentum_optimizer(
+            dims=[16, 1024, 8],
+            name="dgc_momentum",
+            regularization=regularizer.L2Decay(1e-4),
+            use_recompute=True)
+
 
 if __name__ == '__main__':
     unittest.main()
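Taken together, the test changes wrap the test's DGC momentum optimizer in `RecomputeOptimizer` when `use_recompute` is set and alias the two accessor methods the existing assertions need, since the wrapper only forwards the optimizer entry points it defines itself and not every helper of the object it wraps. A minimal, pure-Python sketch of that aliasing idea, using stand-ins rather than the real classes:

```python
# Pure-Python stand-ins to show why the test aliases the accessors; neither
# class below is the real PaddlePaddle implementation.
class FakeDGCMomentum(object):
    """Stand-in for the test's DGC momentum optimizer with its helpers."""

    def get_accumulators(self):
        return {"velocity": {}}      # placeholder accumulator table

    def get_velocity_str(self):
        return "velocity"            # placeholder accumulator name


class FakeRecomputeWrapper(object):
    """Stand-in mirroring how RecomputeOptimizer keeps the inner optimizer
    on self._optimizer without re-exporting its helper methods."""

    def __init__(self, inner):
        self._optimizer = inner


wrapped = FakeRecomputeWrapper(FakeDGCMomentum())

# The same aliasing as in the diff: reach the DGC-specific accessors through
# the wrapped optimizer so assertions written against the bare optimizer
# keep working on the wrapper object.
wrapped.get_accumulators = wrapped._optimizer.get_accumulators
wrapped.get_velocity_str = wrapped._optimizer.get_velocity_str

assert wrapped.get_velocity_str() == "velocity"
```

The new `test_momentum_with_dgc_recompute` case then runs the same checks as the existing DGC momentum test, but through the `RecomputeOptimizer` wrapper and with `L2Decay(1e-4)` regularization.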