Unverified commit ab049978, authored by WangXi, committed by GitHub

[fleet] combine amp and gradient merge, test=develop (#30086)

Parent 88e6dc4a
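For context, this commit makes the `amp` and `gradient_merge` fleet strategies composable. A minimal usage sketch, assuming the static-graph fleet API of this era (exact config keys may vary by Paddle version):

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True                # mixed precision: FP16 casts + loss scaling
strategy.gradient_merge = True     # accumulate grads, apply every k steps
strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
# fleet wraps the inner optimizer with the enabled meta optimizers;
# after this commit GradientMergeOptimizer can wrap AMPOptimizer.
optimizer = fleet.distributed_optimizer(optimizer, strategy)
```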
@@ -25,7 +25,6 @@ class AMPOptimizer(MetaOptimizerBase):
             "LarsOptimizer",
             "LambOptimizer",
             "RecomputeOptimizer",
-            "GradientMergeOptimizer",
             "GraphExecutionOptimizer",
         ]
         self.meta_optimizers_black_list = ["DGCOptimizer"]
......
@@ -21,6 +21,7 @@ class GradientMergeOptimizer(MetaOptimizerBase):
         self.inner_opt = optimizer
         self.wrapped_opt = None
         self.meta_optimizers_white_list = [
+            "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
             "GraphExecutionOptimizer",
......
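Taken together, the two white-list edits reverse the allowed nesting: AMP may no longer wrap gradient merge, but gradient merge may now wrap AMP, so gradient merge becomes the outer pass. A toy sketch of the compatibility rule (simplified names, not the actual fleet source):

```python
# Toy model of fleet's compatibility check: an outer meta optimizer may
# wrap an inner one only if the inner's class name is in its white list.
amp_white_list = [
    "LarsOptimizer", "LambOptimizer",
    "RecomputeOptimizer", "GraphExecutionOptimizer",
]
gm_white_list = [
    "AMPOptimizer",  # added by this commit
    "LarsOptimizer", "LambOptimizer", "GraphExecutionOptimizer",
]

def can_wrap(outer_white_list, inner_name):
    return inner_name in outer_white_list

assert can_wrap(gm_white_list, "AMPOptimizer")                  # GM outside AMP: allowed
assert not can_wrap(amp_white_list, "GradientMergeOptimizer")   # reverse: rejected
```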
@@ -159,9 +159,6 @@ class OptimizerWithMixedPrecision(object):
         params_grads = self._optimizer.backward(
             self._scaled_loss, startup_program, parameter_list, no_grad_set,
             callbacks)
-        # Change the op_role_var attr for some ops, so that gradients
-        # transferred across GPUs can be FP16.
-        update_role_var_grad(train_program, params_grads)
         return params_grads

     def apply_gradients(self, params_grads):
@@ -176,6 +173,10 @@ class OptimizerWithMixedPrecision(object):
         Returns:
             A list of optimize operators.
         """
+        # Change the op_role_var attr for some ops, so that gradients
+        # transferred across GPUs can be FP16.
+        update_role_var_grad(self._train_program, params_grads)
+
         grads = [g for _, g in params_grads]
         if not self._is_distributed:
             with self._train_program._optimized_guard(grads):
......
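The decorator change moves the `op_role_var` rewrite from `backward()` into `apply_gradients()`. A plausible reading: an outer meta optimizer such as gradient merge calls the two phases separately, and apply may run only every k steps, so the rewrite must fire alongside the gradients actually being applied. A plain-Python toy of that call pattern (all classes invented for illustration):

```python
class AMPLike:
    """Stands in for OptimizerWithMixedPrecision (illustrative only)."""

    def backward(self, loss):
        # Produce (param, grad) pairs; no attribute rewrite here anymore.
        return [("w", 2.0 * loss)]

    def apply_gradients(self, params_grads):
        # The rewrite (update_role_var_grad in the real code) now runs here,
        # once per actual apply, seeing the gradients that get applied.
        print("rewrite op_role_var for:", [p for p, _ in params_grads])

class GradientMergeLike:
    """Accumulates gradients and applies them every k steps."""

    def __init__(self, inner, k_steps):
        self.inner, self.k = inner, k_steps
        self.step, self.acc = 0, {}

    def minimize(self, loss):
        for p, g in self.inner.backward(loss):
            self.acc[p] = self.acc.get(p, 0.0) + g
        self.step += 1
        if self.step % self.k == 0:  # apply only on every k-th step
            self.inner.apply_gradients(list(self.acc.items()))
            self.acc = {}

opt = GradientMergeLike(AMPLike(), k_steps=4)
for i in range(8):
    opt.minimize(1.0)  # the rewrite fires on steps 4 and 8 only
```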
@@ -46,6 +46,19 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
         self.assertIn('@GradientMerge', ''.join(vars))
         self.assertIn('subprog', ''.join(vars))
+
+    def test_gm_amp_optimizer(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'gradient_merge')
+        self.set_strategy(strategy, 'amp')
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        print(train_prog)
+        vars = [x.name for x in train_prog.list_vars()]
+        self.assertIn('@GradientMerge', ''.join(vars))
+        self.assertIn('cast', ''.join(vars))
+


 if __name__ == "__main__":
     unittest.main()