Unverified commit f2265d9f, authored by WangXi, committed by GitHub

Fix problem when using recompute and DGC at the same time (#23010)

Parent 8fb8b952
@@ -3926,6 +3926,9 @@ class RecomputeOptimizer(Optimizer):
             parameter_list,
             no_grad_set,
             checkpoints=self._checkpoints)
+        # Note: since we can't use all_reduce_op now,
+        # dgc_op should be the last op of one grad.
+        self._optimizer._append_dgc_ops(params_grads)
         return params_grads

     def apply_optimize(self, loss, startup_program, params_grads):
...
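The hunk above makes RecomputeOptimizer.backward() append the DGC ops of its wrapped DGCMomentumOptimizer, which were previously dropped when the two optimizers were combined. A minimal sketch of the combination this fixes, mirroring the new unittest below; the toy network, argument values, and the _set_checkpoints call are illustrative assumptions, not part of this patch:

import paddle.fluid as fluid
import paddle.fluid.optimizer as optimizer

# Toy network; any program with an intermediate var to checkpoint works.
x = fluid.layers.data(name="x", shape=[32], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="int64")
hidden = fluid.layers.fc(input=x, size=64, act="relu")
pred = fluid.layers.fc(input=hidden, size=10, act="softmax")
loss = fluid.layers.mean(fluid.layers.cross_entropy(input=pred, label=y))

# Wrap the DGC optimizer with recompute, as the new unittest does.
dgc = optimizer.DGCMomentumOptimizer(
    learning_rate=0.01, momentum=0.9, rampup_begin_step=0, num_trainers=2)
recompute = optimizer.RecomputeOptimizer(dgc)
recompute._set_checkpoints([hidden])  # vars kept for recomputation

# Before this fix, backward() here silently skipped _append_dgc_ops(),
# so the dgc ops never made it into the program.
recompute.minimize(loss)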
@@ -35,7 +35,8 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
     def check_dgc_momentum_optimizer(self,
                                      dims=[5, 10, 8],
                                      name="momentum",
-                                     regularization=None):
+                                     regularization=None,
+                                     use_recompute=False):
         init_program = framework.Program()
         program = framework.Program()
         block = program.global_block()
@@ -72,6 +73,13 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
             local_grad_clip_norm=1.0,
             num_trainers=2,
             regularization=regularization)
+
+        if use_recompute:
+            dgc_momentum_optimizer = optimizer.RecomputeOptimizer(
+                dgc_momentum_optimizer)
+            dgc_momentum_optimizer.get_accumulators = dgc_momentum_optimizer._optimizer.get_accumulators
+            dgc_momentum_optimizer.get_velocity_str = dgc_momentum_optimizer._optimizer.get_velocity_str
+
         mean_out = block.create_var(
             dtype="float32", shape=[1], lod_level=0, name="mean.out")
         block.append_op(
@@ -112,8 +120,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
             self.assertAlmostEqual(op.attr('regular_coeff'), coeff)
             print("dgc regular_coeff=" + str(coeff))

-        with open("test_dgc_optimizer_" + name + ".log", "w") as f:
-            program_to_code(program, fout=f)
+        # for local test debug
+        #with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f:
+        #    program_to_code(program, fout=f)

     def test_momentum_without_dgc(self):
         self.check_dgc_momentum_optimizer(
@@ -130,6 +139,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
         self.check_dgc_momentum_optimizer(
             dims=[16, 1024, 8], name="dgc_momentum")

+    def test_momentum_with_dgc_recompute(self):
+        # 16 * 1024 = 16384, use dgc momentum
+        self.check_dgc_momentum_optimizer(
+            dims=[16, 1024, 8],
+            name="dgc_momentum",
+            regularization=regularizer.L2Decay(1e-4),
+            use_recompute=True)
+
 if __name__ == '__main__':
     unittest.main()