Unverified commit c22e1123, authored by zhaoyingli, committed by GitHub

[AutoParallel] fix gradient merge optimize parse (#43169)

* fix gradient merge

* bug fix

* update annotation
Parent 398b96c6
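Reviewer note: gradient merge accumulates the gradients of k_steps micro-batches into persistent gradient_merge variables and runs the optimizer once every k_steps steps, optionally dividing by k_steps first (avg). A minimal eager-style sketch of that behaviour, assuming a user-supplied grad_fn; illustrative only, not the static-graph pass changed below:

    def train_with_gradient_merge(params, grad_fn, micro_batches,
                                  k_steps=4, avg=True, lr=0.01):
        # Illustrative only: the update rule the pass rewrites into the program.
        acc = [0.0 for _ in params]                      # gradient_merge vars, init 0
        for step, batch in enumerate(micro_batches, start=1):
            grads = grad_fn(params, batch)               # backward for one micro-batch
            acc = [a + g for a, g in zip(acc, grads)]    # grad_merge += grad
            if step % k_steps == 0:                      # cond_var is True
                if avg:
                    acc = [a / k_steps for a in acc]
                params = [p - lr * a for p, a in zip(params, acc)]  # optimizer step (plain SGD here)
                acc = [0.0 for _ in params]              # reset accumulators
        return params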
@@ -148,7 +148,7 @@ class Parallelizer:
                 config)
             auto_parallel_recompute_pass.apply([main_program],
                                                [startup_program],
-                                               self._dist_context)
+                                               self._pass_context)

     def _apply_post_optimization(self, main_program, startup_program, rank,
                                  params_grads):
@@ -162,7 +162,7 @@ class Parallelizer:
             auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
                                                    config)
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
-                                              self._dist_context)
+                                              self._pass_context)

         if self._strategy.gradient_merge:
             config = copy.deepcopy(self._strategy.gradient_merge_configs)
@@ -172,4 +172,4 @@ class Parallelizer:
                 "auto_parallel_gradient_merge_pass", config)
             auto_parallel_gradient_merge_pass.apply([main_program],
                                                     [startup_program],
-                                                    self._dist_context)
+                                                    self._pass_context)
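The Parallelizer change above is the first half of the fix: each auto-parallel pass must be applied with the parallelizer's PassContext, not its DistributedContext. A hedged sketch of that call pattern, assuming a config dict shaped like gradient_merge_configs plus the extra keys the parallelizer fills in (dist_context, params_grads):

    from paddle.distributed.passes import new_pass, PassContext

    def apply_gradient_merge_pass(main_program, startup_program, config):
        # 'config' is assumed to carry k_steps/avg plus dist_context and
        # params_grads; the PassContext collects the results of the pass.
        pass_context = PassContext()
        gm_pass = new_pass("auto_parallel_gradient_merge_pass", config)
        gm_pass.apply([main_program], [startup_program], pass_context)
        return pass_context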
@@ -18,10 +18,10 @@ from typing import List, Tuple, Dict, Any

 import paddle
 from paddle.framework import core
-from paddle.fluid import layers
 from paddle.fluid.framework import program_guard, device_guard
+from paddle.fluid import unique_name, layers
+from paddle.fluid.clip import append_gradient_clip_ops
 from .pass_base import PassBase, PassType, register_pass
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.distributed.auto_parallel.utils import set_var_dist_attr
 from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
 from paddle.distributed.auto_parallel.process_group import get_world_process_group
@@ -29,16 +29,8 @@ from paddle.distributed.auto_parallel.process_group import get_world_process_gro

 world_process_group = get_world_process_group()


-def _is_the_backward_op(op):
-    OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-    OpRole = core.op_proto_and_checker_maker.OpRole
-    return OP_ROLE_KEY in op.attr_names and \
-        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
-
-
 def _is_the_optimizer_op(op):
     OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-    OpRole = core.op_proto_and_checker_maker.OpRole
     return OP_ROLE_KEY in op.attr_names and \
         int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
@@ -47,13 +39,13 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
     # 1 create tmp block
     # 2 mv optimizer op from global program to tmp block
     # 3 del the op from dist_context
-    from paddle.distributed.fleet.meta_optimizers.common import OpRole
     main_block = main_program.global_block()
     temp_block = main_program._create_block()
     removed_op_idx = []
     optimize_ops_desc = []
+    skip_ops = ["increment", "elementwise_mod", "equal"]
     for idx, op in enumerate(main_block.ops):
-        if _is_the_optimizer_op(op):
+        if _is_the_optimizer_op(op) and op.type not in skip_ops:
             # append optimizer op to tmp block
             new_op_desc = temp_block.desc.append_op()
             new_op_desc.copy_from(op.desc)
@@ -111,8 +103,17 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
     set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)

     with device_guard("cpu"):
-        # step_var = (step_var + 1) % k_step
-        layers.increment(x=step_var, value=1.0, in_place=True)
+        # step_var += 1
+        increment_op = main_block.append_op(type='increment',
+                                            inputs={'X': [step_var]},
+                                            outputs={'Out': [step_var]},
+                                            attrs={
+                                                'step': float(1.0),
+                                                'op_role': OpRole.Optimize
+                                            })
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            increment_op, world_process_group.ranks, [-1], dist_context)
+
+        # step_var %= k_step
         elementwise_mod_op = main_block.append_op(type='elementwise_mod',
                                                   inputs={
                                                       'X': step_var,
@@ -121,18 +122,19 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                                   outputs={'Out': step_var},
                                                   attrs={
                                                       'axis': -1,
-                                                      'use_mkldnn': False
+                                                      'use_mkldnn': False,
+                                                      'op_role': OpRole.Optimize
                                                   })
         naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
             elementwise_mod_op, world_process_group.ranks, [-1], dist_context)

         # cond_var = (step_var == 0)
         equal_op = main_block.append_op(type='equal',
                                         inputs={
                                             'X': step_var,
                                             'Y': zero_var
                                         },
-                                        outputs={'Out': cond_var})
+                                        outputs={'Out': cond_var},
+                                        attrs={'op_role': OpRole.Optimize})
         naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
             equal_op, world_process_group.ranks, [-1], dist_context)
@@ -154,7 +156,9 @@ def _append_gradient_merge_backward_op(
         _remove_op_role_var(param, grad)

-    param_to_gradient_merge = {}
+    # {grad.name: gradient_merge_var.name} to rename opt inputs
+    grad_to_gradient_merge = {}
+    # {param: gradient_merge_var} to insert scale op and fill_constant op
     new_params_to_grads = []
     # step2: create gradient_merge var and init with 0
     for param, grad in params_grads:
@@ -168,7 +172,6 @@ def _append_gradient_merge_backward_op(
             shape=param_var.shape,
             dtype=param_var.dtype,
             persistable=True)
-        param_to_gradient_merge[param_name] = gradient_merge_var

         ref_process_mesh = ref_dist_attr.process_mesh
         ref_dims_mapping = ref_dist_attr.dims_mapping
@@ -197,17 +200,19 @@ def _append_gradient_merge_backward_op(
                                       outputs={'Out': gradient_merge_var},
                                       attrs={
                                           'axis': -1,
-                                          'use_mkldnn': False
+                                          'use_mkldnn': False,
+                                          'op_role': OpRole.Optimize
                                       })
         new_params_to_grads.append([param, gradient_merge_var])
+        grad_to_gradient_merge[grad.name] = gradient_merge_var.name
         naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
             new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context)
-    return new_params_to_grads, param_to_gradient_merge
+    return new_params_to_grads, grad_to_gradient_merge


 def _create_cond_block_and_update_optimizer(
         main_program, cond_var, new_params_to_grads: List[Tuple[Any, Any]],
-        param_to_gradient_merge: Dict[str, Any], optimize_ops_desc: List[Any],
+        grad_to_gradient_merge: Dict[str, str], optimize_ops_desc: List[Any],
         k_steps, avg):

     def true_apply_gradient():
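The grad_to_gradient_merge mapping returned above is what the conditional block uses to repoint the copied optimizer op descs from per-micro-batch gradients to the accumulated gradient_merge variables. A hedged sketch of that renaming (the variable names here are hypothetical, not taken from the pass):

    # Hypothetical names, for illustration of the rename step only.
    grad_to_gradient_merge = {
        "linear_0.w_0@GRAD": "linear_0.w_0@GRAD@GradientMerge",
    }

    def rename_args(arg_names, mapping):
        # Inputs/outputs that name a micro-batch gradient are redirected to the
        # accumulated variable; everything else is left untouched.
        return [mapping.get(name, name) for name in arg_names]

    print(rename_args(["linear_0.w_0@GRAD", "learning_rate_0"],
                      grad_to_gradient_merge))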
@@ -229,7 +234,7 @@ def _create_cond_block_and_update_optimizer(
                         'bias_after_scale': False
                     })
                 new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
-                                      op_maker.OpRole.Optimize)
+                                      OpRole.Optimize)

         # append optimizer ops
         for op_desc in optimize_ops_desc:
@@ -238,14 +243,14 @@ def _create_cond_block_and_update_optimizer(

             #update input/output
             for input_name in new_op_desc.input_arg_names():
-                if input_name in new_params_to_grads:
-                    new_op_desc._rename_input(input_name,
-                                              new_params_to_grads[input_name])
+                if input_name in grad_to_gradient_merge:
+                    new_op_desc._rename_input(
+                        input_name, grad_to_gradient_merge[input_name])

             for output_name in new_op_desc.output_arg_names():
-                if output_name in new_params_to_grads:
-                    new_op_desc._rename_output(output_name,
-                                               new_params_to_grads[output_name])
+                if output_name in grad_to_gradient_merge:
+                    new_op_desc._rename_output(
+                        output_name, grad_to_gradient_merge[output_name])

             # remove op_role_var
             if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()):
@@ -271,6 +276,8 @@ def _create_cond_block_and_update_optimizer(
                               op_maker.OpRole.Optimize)

     layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
+    cond_op = main_program.global_block().ops[-1]
+    cond_op._set_attr('op_role', OpRole.Optimize)


 def parse_program(main_program, startup_program, params_grads, k_steps, avg,
@@ -285,14 +292,14 @@ def parse_program(main_program, startup_program, params_grads, k_steps, avg,
     main_program._rollback()

     # 3 append gradient merge backward op to main_program
-    new_params_to_grads, param_to_gradient_merge = _append_gradient_merge_backward_op(
+    new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
         main_program, startup_program, params_grads, cond_var.name,
         dist_context)

     # 4 create ConditionalBlock and append gradient merge optimizer ops
     _create_cond_block_and_update_optimizer(main_program, cond_var,
                                             new_params_to_grads,
-                                            param_to_gradient_merge,
+                                            grad_to_gradient_merge,
                                             optimize_ops_desc, k_steps, avg)

@@ -303,7 +310,6 @@ class GradientMergePass(PassBase):
         super(GradientMergePass, self).__init__()
         self.set_attr("k_steps", -1)
         self.set_attr("avg", True)
-        self.set_attr("inner_optimizer", None)

     def _check_self(self):
         if self.get_attr("k_steps") < 1:
...
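Two details above work together: every bookkeeping op the pass now emits (increment, elementwise_mod, equal, and the conditional_block behind layers.cond) carries OpRole.Optimize, while _remove_and_get_optimizer_op skips exactly those op types, so only the real optimizer ops get moved into the conditional block. The counter those three ops implement behaves like this plain-Python sketch (illustrative, not the static ops themselves):

    def tick(step, k_steps):
        # increment -> elementwise_mod -> equal, run on CPU under device_guard
        step = (step + 1) % k_steps
        return step, step == 0        # (new step value, cond_var)

    step = 0
    for i in range(8):
        step, apply_now = tick(step, k_steps=4)
        print(i, apply_now)           # True at i = 3 and i = 7, i.e. every 4th step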
@@ -14,12 +14,12 @@ if((NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS "test_dist_fuse_momentum_pass")
   list(REMOVE_ITEM TEST_OPS "test_dist_fuse_relu_depthwise_conv_pass")
   list(REMOVE_ITEM TEST_OPS "test_dist_fuse_sgd_pass")
-  list(REMOVE_ITEM TEST_OPS "test_dist_gradient_merge_pass")
   list(REMOVE_ITEM TEST_OPS "test_dist_inplace_addto_pass")
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_amp_pass")
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_recompute_pass")
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
   list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
+  list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass")
 endif()

 foreach(TEST_OP ${TEST_OPS})
...
@@ -25,20 +25,14 @@ import paddle.nn as nn
 import paddle.utils as utils
 import paddle.static as static
 import paddle.nn.functional as F
+import paddle.distributed.fleet as fleet
 import paddle.distributed.auto_parallel as auto
 from paddle.fluid.initializer import NumpyArrayInitializer
-from paddle.distributed.passes import new_pass, PassManager, PassContext
-import paddle.distributed.fleet as fleet
-from dist_pass_test_base import DistPassTestBase
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
+from auto_parallel_pass_test_base import AutoPallelPassTestBase

 logging.getLogger().setLevel(logging.INFO)

 paddle.enable_static()
-_global_parallel_strategy = None
-_global_process_mesh = None
-#np.set_printoptions(suppress=True)


 class MLPLayer(nn.Layer):
@@ -103,13 +97,11 @@ class MLPLayer(nn.Layer):

 def mlp_forward(input, label, hidden_size):
-    if _global_parallel_strategy == "dp":
-        auto.shard_tensor(input,
-                          dist_attr={
-                              "process_mesh": _global_process_mesh,
-                              "dims_mapping": [0, -1]
-                          })
+    auto.shard_tensor(input,
+                      dist_attr={
+                          "process_mesh": [0],
+                          "dims_mapping": [-1, -1]
+                      })

     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
                    initializer_range=0.02)
@@ -119,40 +111,33 @@ def mlp_forward(input, label, hidden_size):
     return loss


-class TestGradientMergePass(DistPassTestBase):
+class TestGradientMergePass(AutoPallelPassTestBase):

     def init(self):
-        self._params_grads = None
-        self._config = {"k_steps": 4, "avg": True}
-        #self._config["dist_context"] = DistributedContext()
+        paddle.seed(2022)
+        random.seed(2022)
+        np.random.seed(2022)

-    def apply_passes(self, main_prog, startup_prog):
-        #self._config["params_grads"] = self._params_grads
-        #pass_context = PassContext()
-        #auto_parallel_gradient_merge_pass = new_pass(
-        #    "auto_parallel_gradient_merge_pass", self._config)
-        #auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog],
-        #                                        pass_context)
+    def apply_passes(self):
         dist_strategy = fleet.DistributedStrategy()
+        dist_strategy.semi_auto = True
         dist_strategy.gradient_merge = True
         dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
-        dist_strategy.semi_auto = True
         fleet.init(is_collective=True, strategy=dist_strategy)

     def test_result(self):
         no_pass_rets = self._distributed_launch(model=None,
                                                 apply_pass=False,
                                                 gpus=[0],
-                                                gradient_merge=False,
                                                 batch_size=32,
-                                                hidden_size=128,
                                                 max_step=2)
         pass_rets = self._distributed_launch(model=None,
                                              apply_pass=True,
                                              gpus=[0],
-                                             gradient_merge=True,
                                              batch_size=8,
-                                             hidden_size=128,
                                              max_step=8)
-        """
         # avg loss for gradient_merge pass
         avg_loss = 0
         pass_avg_ret_list = []
@@ -167,40 +152,16 @@ class TestGradientMergePass(DistPassTestBase):
         for no_pass_ret, pass_ret in zip(no_pass_rets[0], pass_avg_ret_list):
             print(f"no_pass_ret={no_pass_ret}, pass_ret={pass_ret}")
             self.assertTrue(
-                np.isclose(
-                    no_pass_ret,
-                    pass_ret,
-                    rtol=self.rtol,
-                    atol=self.atol,
-                    equal_nan=self.equal_nan))
-        """
-
-    def get_model(self, place, gradient_merge, batch_size, max_step):
-        paddle.seed(2021)
-        random.seed(2021)
-        np.random.seed(2021)
-        hidden_size = 128
-
-        global _global_parallel_strategy
-        global _global_process_mesh
-        world_size = paddle.distributed.get_world_size()
-        if world_size == 1:
-            _global_parallel_strategy = "dp"
-            _global_process_mesh = auto.ProcessMesh([0])
-        elif world_size == 2:
-            _global_parallel_strategy = "dp"
-            _global_process_mesh = auto.ProcessMesh([0, 1])
+                np.isclose(no_pass_ret,
+                           pass_ret,
+                           rtol=self.rtol,
+                           atol=self.atol,
+                           equal_nan=self.equal_nan))

+    def get_model(self, place, batch_size, hidden_size, max_step):
         train_program = static.Program()
         startup_program = static.Program()
-        dist_strategy = fleet.DistributedStrategy()
-        dist_strategy.semi_auto = True
-        #if gradient_merge:
-        #    dist_strategy.gradient_merge = True
-        #    dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
-        fleet.init(is_collective=True, strategy=dist_strategy)
         with static.program_guard(train_program, startup_program), \
             utils.unique_name.guard():
             input = static.data(name="input",
@@ -212,8 +173,7 @@ class TestGradientMergePass(DistPassTestBase):
             input.stop_gradient = False
             loss = mlp_forward(input, label, hidden_size)

-            optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
-            #optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
+            optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01)
             optimizer = fleet.distributed_optimizer(optimizer)
             _, self._params_grads, dist_startup_prog, dist_main_prog = optimizer.minimize(
                 loss, startup_program)
...
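One way to read the test settings above (a hedged interpretation of the numbers, not something the test asserts explicitly): the baseline run uses batch_size=32 with max_step=2, while the gradient-merge run uses batch_size=8 with max_step=8 and k_steps=4, so each merged update still covers 32 samples and both runs perform two optimizer updates; the per-update averaged losses are then compared with np.isclose.

    k_steps = 4
    no_pass = {"batch_size": 32, "max_step": 2}
    with_pass = {"batch_size": 8, "max_step": 8}

    # each merged update accumulates k_steps micro-batches of 8 samples
    assert with_pass["batch_size"] * k_steps == no_pass["batch_size"]
    # 8 micro-steps -> 8 // 4 = 2 optimizer updates, matching max_step=2
    assert with_pass["max_step"] // k_steps == no_pass["max_step"]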