Unverified · Commit c22e1123 authored by zhaoyingli, committed by GitHub

[AutoParallel] fix gradient merge optimize parse (#43169)

* fix gradient merge

* bug fix

* update annotation
Parent 398b96c6
......@@ -148,7 +148,7 @@ class Parallelizer:
config)
auto_parallel_recompute_pass.apply([main_program],
[startup_program],
self._dist_context)
self._pass_context)
def _apply_post_optimization(self, main_program, startup_program, rank,
params_grads):
......@@ -162,7 +162,7 @@ class Parallelizer:
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._dist_context)
self._pass_context)
if self._strategy.gradient_merge:
config = copy.deepcopy(self._strategy.gradient_merge_configs)
......@@ -172,4 +172,4 @@ class Parallelizer:
"auto_parallel_gradient_merge_pass", config)
auto_parallel_gradient_merge_pass.apply([main_program],
[startup_program],
self._dist_context)
self._pass_context)
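For orientation, the change in this file is mechanical: each pass is now applied against the Parallelizer's pass context instead of its distributed context. A minimal, hedged sketch of that calling pattern, reusing the `new_pass`/`PassContext` API imported by the test below (the helper name and the bare config dict are illustrative assumptions, not the Parallelizer's actual code):

    from paddle.distributed.passes import new_pass, PassContext

    def apply_gradient_merge_pass(main_program, startup_program, config,
                                  pass_context=None):
        # Illustrative helper: build the pass from its config dict and apply it
        # with a PassContext as the third argument (the argument this commit
        # corrects from self._dist_context to self._pass_context).
        pass_context = pass_context or PassContext()
        gm_pass = new_pass("auto_parallel_gradient_merge_pass", config)
        gm_pass.apply([main_program], [startup_program], pass_context)
        return pass_context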
......@@ -18,10 +18,10 @@ from typing import List, Tuple, Dict, Any
import paddle
from paddle.framework import core
from paddle.fluid import layers
from paddle.fluid.framework import program_guard, device_guard
from paddle.fluid import unique_name, layers
from paddle.fluid.clip import append_gradient_clip_ops
from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.utils import set_var_dist_attr
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
......@@ -29,16 +29,8 @@ from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_process_group = get_world_process_group()
def _is_the_backward_op(op):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
def _is_the_optimizer_op(op):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
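The surviving helper relies on op roles being bit flags, so it tests membership with a bitwise AND rather than equality. A tiny sketch of that idea (the flag values are illustrative assumptions, not Paddle's actual enum):

    # Roles can be combined (e.g. an optimizer op that also schedules the
    # learning rate), so an equality test would miss combined roles.
    OPTIMIZE = 1 << 1   # assumed value for illustration
    LRSCHED = 1 << 4    # assumed value for illustration

    def is_optimizer_op(op_role: int) -> bool:
        return bool(op_role & OPTIMIZE)

    assert is_optimizer_op(OPTIMIZE)
    assert is_optimizer_op(OPTIMIZE | LRSCHED)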
......@@ -47,13 +39,13 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
# 1 create tmp block
# 2 mv optimizer op from global program to tmp block
# 3 del the op from dist_context
from paddle.distributed.fleet.meta_optimizers.common import OpRole
main_block = main_program.global_block()
temp_block = main_program._create_block()
removed_op_idx = []
optimize_ops_desc = []
skip_ops = ["increment", "elementwise_mod", "equal"]
for idx, op in enumerate(main_block.ops):
if _is_the_optimizer_op(op):
if _is_the_optimizer_op(op) and op.type not in skip_ops:
# append optimizer op to tmp block
new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc)
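The new `skip_ops` filter keeps the step-counter ops that this pass itself emits (`increment`, `elementwise_mod`, `equal`) in the main block; only the remaining optimizer-role ops are moved into the temporary block and later replayed inside the conditional block. A small sketch of that selection rule (the function name is illustrative):

    SKIP_OPS = {"increment", "elementwise_mod", "equal"}

    def should_move_to_temp_block(op_type: str, has_optimize_role: bool) -> bool:
        # Optimizer-role ops are moved, except the counter ops that must run
        # on every micro-step to keep the gradient-merge condition up to date.
        return has_optimize_role and op_type not in SKIP_OPS

    assert should_move_to_temp_block("sgd", True)
    assert not should_move_to_temp_block("increment", True)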
......@@ -111,8 +103,17 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
with device_guard("cpu"):
# step_var = (step_var + 1) % k_step
layers.increment(x=step_var, value=1.0, in_place=True)
# step_var += 1
increment_op = main_block.append_op(type='increment',
inputs={'X': [step_var]},
outputs={'Out': [step_var]},
attrs={
'step': float(1.0),
'op_role': OpRole.Optimize
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
increment_op, world_process_group.ranks, [-1], dist_context)
# step_var %= k_step
elementwise_mod_op = main_block.append_op(type='elementwise_mod',
inputs={
'X': step_var,
......@@ -121,18 +122,19 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
outputs={'Out': step_var},
attrs={
'axis': -1,
'use_mkldnn': False
'use_mkldnn': False,
'op_role': OpRole.Optimize
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mod_op, world_process_group.ranks, [-1], dist_context)
# cond_var = (step_var == 0)
equal_op = main_block.append_op(type='equal',
inputs={
'X': step_var,
'Y': zero_var
},
outputs={'Out': cond_var})
outputs={'Out': cond_var},
attrs={'op_role': OpRole.Optimize})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
equal_op, world_process_group.ranks, [-1], dist_context)
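Taken together, the three ops above (now all tagged `op_role: Optimize` and given dist attrs) implement the gradient-merge trigger. In plain Python the condition they compute is roughly:

    def update_cond(step: int, k_steps: int):
        # step_var += 1; step_var %= k_steps; cond_var = (step_var == 0)
        step = (step + 1) % k_steps
        return step, step == 0

    # e.g. with k_steps = 4 the optimizer fires on every 4th micro-batch
    step, fired = 0, []
    for _ in range(8):
        step, cond = update_cond(step, 4)
        fired.append(cond)
    assert fired == [False, False, False, True] * 2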
......@@ -154,7 +156,9 @@ def _append_gradient_merge_backward_op(
_remove_op_role_var(param, grad)
param_to_gradient_merge = {}
# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op
new_params_to_grads = []
# step2: create gradient_merge var and init with 0
for param, grad in params_grads:
......@@ -168,7 +172,6 @@ def _append_gradient_merge_backward_op(
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True)
param_to_gradient_merge[param_name] = gradient_merge_var
ref_process_mesh = ref_dist_attr.process_mesh
ref_dims_mapping = ref_dist_attr.dims_mapping
......@@ -197,17 +200,19 @@ def _append_gradient_merge_backward_op(
outputs={'Out': gradient_merge_var},
attrs={
'axis': -1,
'use_mkldnn': False
'use_mkldnn': False,
'op_role': OpRole.Optimize
})
new_params_to_grads.append([param, gradient_merge_var])
grad_to_gradient_merge[grad.name] = gradient_merge_var.name
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context)
return new_params_to_grads, param_to_gradient_merge
return new_params_to_grads, grad_to_gradient_merge
def _create_cond_block_and_update_optimizer(
main_program, cond_var, new_params_to_grads: List[Tuple[Any, Any]],
param_to_gradient_merge: Dict[str, Any], optimize_ops_desc: List[Any],
grad_to_gradient_merge: Dict[str, str], optimize_ops_desc: List[Any],
k_steps, avg):
def true_apply_gradient():
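In other words, each parameter gets a persistable accumulation variable, the backward section gains an `elementwise_add` that folds the current gradient into it, and `grad_to_gradient_merge` records the gradient-name-to-buffer-name mapping that the optimizer rewrite below needs. A plain-Python sketch of the accumulation (names are illustrative):

    # merged[g_name] plays the role of gradient_merge_var: it persists across
    # micro-batches and receives "merged += grad" on every backward pass.
    def accumulate(merged: dict, grads: dict) -> dict:
        for g_name, g_value in grads.items():
            merged[g_name] = merged.get(g_name, 0.0) + g_value
        return merged

    merged = {}
    accumulate(merged, {"w@GRAD": 0.5})
    accumulate(merged, {"w@GRAD": 0.25})
    assert merged["w@GRAD"] == 0.75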
......@@ -229,7 +234,7 @@ def _create_cond_block_and_update_optimizer(
'bias_after_scale': False
})
new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
op_maker.OpRole.Optimize)
OpRole.Optimize)
# append optimizer ops
for op_desc in optimize_ops_desc:
......@@ -238,14 +243,14 @@ def _create_cond_block_and_update_optimizer(
#update input/output
for input_name in new_op_desc.input_arg_names():
if input_name in new_params_to_grads:
new_op_desc._rename_input(input_name,
new_params_to_grads[input_name])
if input_name in grad_to_gradient_merge:
new_op_desc._rename_input(
input_name, grad_to_gradient_merge[input_name])
for output_name in new_op_desc.output_arg_names():
if output_name in new_params_to_grads:
new_op_desc._rename_output(output_name,
new_params_to_grads[output_name])
if output_name in grad_to_gradient_merge:
new_op_desc._rename_output(
output_name, grad_to_gradient_merge[output_name])
# remove op_role_var
if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()):
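This hunk is the heart of the fix: the rename map handed to the replayed optimizer descs must be keyed by gradient names (`grad_to_gradient_merge`), not by the `(param, merged_grad)` pairs in `new_params_to_grads`, otherwise no optimizer input or output is ever rewritten. A minimal sketch of the corrected rewrite (the merged-variable suffix is an assumption for illustration):

    def rename_args(arg_names, grad_to_gradient_merge):
        # Replace every gradient argument with its accumulation buffer; leave
        # all other arguments (params, moments, learning rate, ...) untouched.
        return [grad_to_gradient_merge.get(name, name) for name in arg_names]

    mapping = {"w@GRAD": "w@GRAD@MERGED"}  # assumed buffer name
    assert rename_args(["w", "w@GRAD", "LearningRate"], mapping) == \
        ["w", "w@GRAD@MERGED", "LearningRate"]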
......@@ -271,6 +276,8 @@ def _create_cond_block_and_update_optimizer(
op_maker.OpRole.Optimize)
layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
cond_op = main_program.global_block().ops[-1]
cond_op._set_attr('op_role', OpRole.Optimize)
def parse_program(main_program, startup_program, params_grads, k_steps, avg,
......@@ -285,14 +292,14 @@ def parse_program(main_program, startup_program, params_grads, k_steps, avg,
main_program._rollback()
# 3 append gradient merge backward op to main_program
new_params_to_grads, param_to_gradient_merge = _append_gradient_merge_backward_op(
new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, cond_var.name,
dist_context)
# 4 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(main_program, cond_var,
new_params_to_grads,
param_to_gradient_merge,
grad_to_gradient_merge,
optimize_ops_desc, k_steps, avg)
......@@ -303,7 +310,6 @@ class GradientMergePass(PassBase):
super(GradientMergePass, self).__init__()
self.set_attr("k_steps", -1)
self.set_attr("avg", True)
self.set_attr("inner_optimizer", None)
def _check_self(self):
if self.get_attr("k_steps") < 1:
......
......@@ -14,12 +14,12 @@ if((NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS "test_dist_fuse_momentum_pass")
list(REMOVE_ITEM TEST_OPS "test_dist_fuse_relu_depthwise_conv_pass")
list(REMOVE_ITEM TEST_OPS "test_dist_fuse_sgd_pass")
list(REMOVE_ITEM TEST_OPS "test_dist_gradient_merge_pass")
list(REMOVE_ITEM TEST_OPS "test_dist_inplace_addto_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_amp_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_recompute_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass")
endif()
foreach(TEST_OP ${TEST_OPS})
......
......@@ -25,20 +25,14 @@ import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.passes import new_pass, PassManager, PassContext
import paddle.distributed.fleet as fleet
from dist_pass_test_base import DistPassTestBase
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.fluid.initializer import NumpyArrayInitializer
from auto_parallel_pass_test_base import AutoPallelPassTestBase
logging.getLogger().setLevel(logging.INFO)
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
#np.set_printoptions(suppress=True)
class MLPLayer(nn.Layer):
......@@ -103,13 +97,11 @@ class MLPLayer(nn.Layer):
def mlp_forward(input, label, hidden_size):
if _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input,
dist_attr={
"process_mesh": [0],
"dims_mapping": [-1, -1]
})
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
......@@ -119,40 +111,33 @@ def mlp_forward(input, label, hidden_size):
return loss
class TestGradientMergePass(DistPassTestBase):
class TestGradientMergePass(AutoPallelPassTestBase):
def init(self):
self._params_grads = None
self._config = {"k_steps": 4, "avg": True}
#self._config["dist_context"] = DistributedContext()
def apply_passes(self, main_prog, startup_prog):
#self._config["params_grads"] = self._params_grads
#pass_context = PassContext()
#auto_parallel_gradient_merge_pass = new_pass(
# "auto_parallel_gradient_merge_pass", self._config)
#auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog],
# pass_context)
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.gradient_merge = True
dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_result(self):
no_pass_rets = self._distributed_launch(model=None,
apply_pass=False,
gpus=[0],
gradient_merge=False,
batch_size=32,
hidden_size=128,
max_step=2)
pass_rets = self._distributed_launch(model=None,
apply_pass=True,
gpus=[0],
gradient_merge=True,
batch_size=8,
hidden_size=128,
max_step=8)
"""
# avg loss for gradient_merge pass
avg_loss = 0
pass_avg_ret_list = []
......@@ -167,40 +152,16 @@ class TestGradientMergePass(DistPassTestBase):
for no_pass_ret, pass_ret in zip(no_pass_rets[0], pass_avg_ret_list):
print(f"no_pass_ret={no_pass_ret}, pass_ret={pass_ret}")
self.assertTrue(
np.isclose(
no_pass_ret,
pass_ret,
rtol=self.rtol,
atol=self.atol,
equal_nan=self.equal_nan))
"""
def get_model(self, place, gradient_merge, batch_size, max_step):
paddle.seed(2021)
random.seed(2021)
np.random.seed(2021)
np.isclose(no_pass_ret,
pass_ret,
rtol=self.rtol,
atol=self.atol,
equal_nan=self.equal_nan))
hidden_size = 128
global _global_parallel_strategy
global _global_process_mesh
world_size = paddle.distributed.get_world_size()
if world_size == 1:
_global_parallel_strategy = "dp"
_global_process_mesh = auto.ProcessMesh([0])
elif world_size == 2:
_global_parallel_strategy = "dp"
_global_process_mesh = auto.ProcessMesh([0, 1])
def get_model(self, place, batch_size, hidden_size, max_step):
train_program = static.Program()
startup_program = static.Program()
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
#if gradient_merge:
# dist_strategy.gradient_merge = True
# dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
fleet.init(is_collective=True, strategy=dist_strategy)
with static.program_guard(train_program, startup_program), \
utils.unique_name.guard():
input = static.data(name="input",
......@@ -212,8 +173,7 @@ class TestGradientMergePass(DistPassTestBase):
input.stop_gradient = False
loss = mlp_forward(input, label, hidden_size)
optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
#optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer)
_, self._params_grads, dist_startup_prog, dist_main_prog = optimizer.minimize(
loss, startup_program)
......