Unverified commit 0443c6f4 authored by xiayanming, committed by GitHub

[Auto Parallel] Gradient merge pass support dist attribute (#40737)

* [Auto Parallel] gradient merge pass support dist attribute
Parent a8f86600
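The diff below touches the gradient-merge pass and its unit test (file names are collapsed in this view). The change follows a single pattern across the pass: every variable and every op that the pass inserts into the program is annotated with a distributed attribute, so the auto-parallel partitioner can place it later. A minimal sketch of that pattern is given here, using the helpers imported in the diff; the counter variable, the appended op, and the function name are illustrative only and not part of the pass itself:

# A minimal sketch of the annotation pattern used throughout this commit.
# The helper names exist in Paddle's auto-parallel utils (see the imports
# added in the diff); the counter var/op here are illustrative only.
from paddle.fluid import layers
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.distributed.auto_parallel.utils import (
    set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping)

world_process_group = get_world_process_group()


def annotate_new_counter(main_block, dist_context):
    # 1) Every var the pass creates gets a tensor dist attr: here the scalar
    #    is replicated (dims_mapping == [-1]) over the world process mesh.
    counter = layers.create_global_var(
        name="gm_example_counter", shape=[1], value=0,
        dtype='int32', persistable=True, force_cpu=True)
    set_var_dist_attr(dist_context, counter, [-1], world_process_group.ranks)

    # 2) Every op the pass appends gets a matching op dist attr on the same
    #    mesh and mapping, so the partitioner can handle it afterwards.
    increment_op = main_block.append_op(
        type='increment',
        inputs={'X': counter},
        outputs={'Out': counter},
        attrs={'step': 1.0})
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        increment_op, world_process_group.ranks, [-1], dist_context)
    return counter

The sketch is meant to be called from inside a pass, after fleet/auto-parallel initialization has populated the world process group.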
@@ -22,6 +22,10 @@ from paddle.fluid.framework import program_guard, device_guard
 from paddle.fluid import unique_name, layers
 from paddle.fluid.clip import append_gradient_clip_ops
 from .pass_base import PassBase, PassType, register_pass
+from paddle.distributed.auto_parallel.utils import set_var_dist_attr
+from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
+from paddle.distributed.auto_parallel.process_group import get_world_process_group
+world_process_group = get_world_process_group()
 
 
 def _is_the_backward_op(op):
@@ -68,15 +72,11 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
 def _remove_op_role_var(param, grad):
     op_maker = core.op_proto_and_checker_maker
     op = grad.op
-    assert _is_the_backward_op(op), \
-        'grad.op={} is not the backward op which produces the grad={}' \
-        .format(op, grad.name)
-
     if op.has_attr(op_maker.kOpRoleVarAttrName()):
         op._remove_attr(op_maker.kOpRoleVarAttrName())
 
 
-def _get_gm_cond_var(main_program, k_steps):
+def _get_gm_cond_var(main_program, k_steps, dist_context):
     main_block = main_program.global_block()
     # Add const var
     k_step_var = layers.create_global_var(
@@ -86,6 +86,7 @@ def _get_gm_cond_var(main_program, k_steps):
         dtype='int32',
         persistable=True,
         force_cpu=True)
+    set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks)
 
     zero_var = layers.create_global_var(
         name="gradient_merge_zero",
@@ -94,6 +95,7 @@ def _get_gm_cond_var(main_program, k_steps):
         dtype='int32',
         persistable=True,
         force_cpu=True)
+    set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks)
 
     # Add step var & cond var
     step_var = layers.create_global_var(
@@ -103,6 +105,7 @@ def _get_gm_cond_var(main_program, k_steps):
         dtype='int32',
         persistable=True,
         force_cpu=True)
+    set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)
 
     cond_var = layers.create_global_var(
         name="gradient_merge_cond",
@@ -111,24 +114,29 @@ def _get_gm_cond_var(main_program, k_steps):
         dtype='bool',
         persistable=False,
         force_cpu=True)
+    set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
 
     with device_guard("cpu"):
         # step_var = (step_var + 1) % k_step
         layers.increment(x=step_var, value=1.0, in_place=True)
-        main_block.append_op(
+        elementwise_mod_op = main_block.append_op(
             type='elementwise_mod',
             inputs={'X': step_var,
                     'Y': k_step_var},
             outputs={'Out': step_var},
             attrs={'axis': -1,
                    'use_mkldnn': False})
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            elementwise_mod_op, world_process_group.ranks, [-1], dist_context)
 
         # cond_var = (step_var == 0)
-        main_block.append_op(
+        equal_op = main_block.append_op(
             type='equal',
             inputs={'X': step_var,
                     'Y': zero_var},
             outputs={'Out': cond_var})
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            equal_op, world_process_group.ranks, [-1], dist_context)
 
     return cond_var
@@ -137,7 +145,8 @@ def _append_gradient_merge_backward_op(
         main_program,
         startup_program,
         params_grads: List[Tuple[Any, Any]],
-        cond_var_name: str) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
+        cond_var_name: str,
+        dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
     main_block = main_program.global_block()
     startup_block = startup_program.global_block()
 
@@ -156,12 +165,19 @@ def _append_gradient_merge_backward_op(
         param_name = param.name
         param_var = main_block.var(param_name)
         assert (param_var is not None)
+        ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
+        assert ref_dist_attr is not None
         gradient_merge_var = main_block.create_var(
             name=param_name + "@GRAD@GradientMerge",
             shape=param_var.shape,
             dtype=param_var.dtype,
             persistable=True)
         param_to_gradient_merge[param_name] = gradient_merge_var
+        ref_process_mesh = ref_dist_attr.process_mesh
+        ref_dims_mapping = ref_dist_attr.dims_mapping
+        set_var_dist_attr(dist_context, gradient_merge_var, ref_dims_mapping,
+                          ref_process_mesh)
 
         startup_gradient_merge_var = startup_block.create_var(
             name=param_name + "@GRAD@GradientMerge",
@@ -186,6 +202,8 @@ def _append_gradient_merge_backward_op(
             attrs={'axis': -1,
                    'use_mkldnn': False})
         new_params_to_grads.append([param, gradient_merge_var])
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context)
 
     return new_params_to_grads, param_to_gradient_merge
@@ -240,7 +258,7 @@ def _create_cond_block_and_update_optimizer(
             new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName())
 
             # op's update Grad
-            if new_op_desc.input("Grad"):
+            if core.grad_var_suffix() in new_op_desc.input_arg_names():
                 grad_value = new_op_desc.input("Grad")[0]
                 # TODO FIXME(xym) support fp16
                 grad_merge_value = grad_value + '@GradientMerge'
@@ -265,7 +283,7 @@
 def parse_program(main_program, startup_program, params_grads, k_steps, avg,
                   dist_context):
     # 1 create gradient_merge_cond
-    cond_var = _get_gm_cond_var(main_program, k_steps)
+    cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
 
     # 2 remove optimizer_op from main_program
     optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
@@ -275,7 +293,8 @@ def parse_program(main_program, startup_program, params_grads, k_steps, avg,
 
     # 3 append gradient merge backward op to main_program
     new_params_to_grads, param_to_gradient_merge = _append_gradient_merge_backward_op(
-        main_program, startup_program, params_grads, cond_var.name)
+        main_program, startup_program, params_grads, cond_var.name,
+        dist_context)
 
     # 4 create ConditionalBlock and append gradient merge optimizer ops
     _create_cond_block_and_update_optimizer(
......
@@ -31,6 +31,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer
 from paddle.distributed.passes import new_pass, PassManager, PassContext
 import paddle.distributed.fleet as fleet
 from dist_pass_test_base import DistPassTestBase
+from paddle.distributed.auto_parallel.dist_context import DistributedContext
 
 logging.getLogger().setLevel(logging.INFO)
 paddle.enable_static()
@@ -111,14 +112,20 @@ class TestGradientMergePass(DistPassTestBase):
     def init(self):
         self._params_grads = None
         self._config = {"k_steps": 4, "avg": True}
+        #self._config["dist_context"] = DistributedContext()
 
     def apply_passes(self, main_prog, startup_prog):
-        self._config["params_grads"] = self._params_grads
-        pass_context = PassContext()
-        auto_parallel_gradient_merge_pass = new_pass(
-            "auto_parallel_gradient_merge_pass", self._config)
-        auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog],
-                                                pass_context)
+        #self._config["params_grads"] = self._params_grads
+        #pass_context = PassContext()
+        #auto_parallel_gradient_merge_pass = new_pass(
+        #    "auto_parallel_gradient_merge_pass", self._config)
+        #auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog],
+        #                                        pass_context)
+        dist_strategy = fleet.DistributedStrategy()
+        dist_strategy.gradient_merge = True
+        dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
+        dist_strategy.semi_auto = True
+        fleet.init(is_collective=True, strategy=dist_strategy)
 
     def test_result(self):
         no_pass_rets = self._distributed_launch(
@@ -135,7 +142,7 @@ class TestGradientMergePass(DistPassTestBase):
             gradient_merge=True,
             batch_size=8,
             max_step=8)
-
+        """
         # avg loss for gradient_merge pass
         avg_loss = 0
         pass_avg_ret_list = []
@@ -156,6 +163,7 @@ class TestGradientMergePass(DistPassTestBase):
                     rtol=self.rtol,
                     atol=self.atol,
                     equal_nan=self.equal_nan))
+        """
 
     def get_model(self, place, gradient_merge, batch_size, max_step):
         paddle.seed(2021)
......
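After this change the test no longer builds a PassContext and applies auto_parallel_gradient_merge_pass by hand; gradient merge is switched on through the fleet distributed strategy with semi-auto parallel enabled, and the pass is expected to be applied by the auto-parallel flow when the program is compiled. A minimal configuration sketch mirroring the updated apply_passes above (model building and optimizer setup are omitted):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

# Gradient merge is configured on the strategy; with semi_auto enabled the
# auto-parallel flow is responsible for applying the gradient-merge pass.
dist_strategy = fleet.DistributedStrategy()
dist_strategy.gradient_merge = True
dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)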