Unverified commit 5592f8ad, authored by JZ-LIANG, committed by GitHub

[Auto Parallel-Performance] Sharding Comm Optimization (#48604)

* remove deps and prior comm

* grad comm fuse

* add deps for amp&global norm

* stage2 broadcast prior deps

* stage2 grad overlap

* stream_analyzer bugfix

* overlap enable

* dep op namescope

* depend support multiple inputs

* check finite deps

* stage2 param comm overlap

* Set kD2HStream

* grad comm hierarchical

* grad comm hierarchical

* new unitest
Co-authored-by: Nchenruibiao <chenruibiao@baidu.com>
Parent 852c8db3
......@@ -90,8 +90,12 @@ SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "overlap_grad_comm", False)
set_field_default_config(SHARDING, "bucket_size_numel", -1)
set_field_default_config(SHARDING, "enable_overlap", False)
set_field_default_config(SHARDING, "param_comm_stream_num", 1)
set_field_default_config(SHARDING, "grad_comm_stream_num", 1)
set_field_default_config(SHARDING, "param_bucket_size_numel", 1)
set_field_default_config(SHARDING, "grad_bucket_size_numel", 1)
set_field_default_config(SHARDING, "enable_hierarchical_comm", False)
set_field_default_config(SHARDING, "partition_algor", "greedy_even")
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
......
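For reference, the new defaults above map one-to-one onto `strategy.sharding` fields. A minimal configuration sketch, taken from the unit test added later in this PR (the values are the test's, not recommendations):

import paddle
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2
sharding.degree = 2
sharding.enable_overlap = True              # overlap sharding comm with computation
sharding.param_comm_stream_num = 2          # parameter broadcasts use 2 comm streams
sharding.grad_comm_stream_num = 2           # gradient reduces use 2 comm streams
sharding.param_bucket_size_numel = 512 * 512
sharding.grad_bucket_size_numel = 128 * 128
sharding.enable_hierarchical_comm = False   # intra-node then inter-node grad reduce
sharding.partition_algor = "use_order"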
......@@ -45,6 +45,15 @@ class ParallelMode:
MoEParallel = "auto_parallel/moe_parallel"
class SyncMode:
"""
the synchronization mode for communication or auxiliary operators
"""
AmpFlagSync = "auto_parallel/amp_flag_synchorization"
GlobalNormSync = "auto_parallel/global_norm_synchorization"
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
......@@ -441,7 +450,7 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
assert (
dims_mapping is not None
), "Unexception: dims_mapping of output [{}] of op [{}] is None".format(
), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format(
grad_var.name, op_dist_attr.op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
......@@ -502,6 +511,22 @@ def is_data_parallel_reduce_op(op):
)
def is_amp_flag_sync_op(op):
return (
op.type == "c_allreduce_max"
and op.desc.has_attr("op_namescope")
and SyncMode.AmpFlagSync in op.desc.attr("op_namescope")
)
def is_global_norm_sync_op(op):
return (
op.type == "c_allreduce_sum"
and op.desc.has_attr("op_namescope")
and SyncMode.GlobalNormSync in op.desc.attr("op_namescope")
)
def is_in_backward_phase(dist_ctx):
# NOTE currently high-order differential in Paddle does NOT distinguish gradient computation operators
# in Forward phase and operators in Backward phase (both with op_role=1), which will mislead
......
......@@ -24,6 +24,7 @@ from ..utils import set_dist_op_desc_original_id, set_var_dist_attr
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
SyncMode,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
......@@ -166,6 +167,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Optimize,
},
)
allreduce_op._set_attr('op_namescope', str('/') + SyncMode.AmpFlagSync)
cast_op2 = main_block.append_op(
type='cast',
inputs={'X': inf_var_int32},
......
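The `/auto_parallel/amp_flag_synchorization` namescope written here is the marker that `is_amp_flag_sync_op` above keys on. A self-contained toy illustration of that tag/detect handshake, using a dummy op object in place of a real `c_allreduce_max` (everything below is illustrative only, not Paddle API):

class _DummyDesc:
    def __init__(self):
        self._attrs = {}
    def has_attr(self, name):
        return name in self._attrs
    def attr(self, name):
        return self._attrs[name]

class _DummyOp:
    def __init__(self, op_type):
        self.type = op_type
        self.desc = _DummyDesc()
    def _set_attr(self, name, value):
        self.desc._attrs[name] = value

AmpFlagSync = "auto_parallel/amp_flag_synchorization"  # value of SyncMode.AmpFlagSync above

op = _DummyOp("c_allreduce_max")
op._set_attr("op_namescope", "/" + AmpFlagSync)  # what the dist op impl does here
# ... and later the supplement-dependencies pass detects it the same way is_amp_flag_sync_op does:
assert (
    op.type == "c_allreduce_max"
    and op.desc.has_attr("op_namescope")
    and AmpFlagSync in op.desc.attr("op_namescope")
)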
......@@ -318,6 +318,16 @@ class Parallelizer:
[main_program], [startup_program], self._pass_context
)
# add explicit deps for the new (standalone) executor
config = {}
config["dist_context"] = self._dist_context
APSED_pass = new_pass(
"auto_parallel_supplement_explicit_dependencies", config
)
APSED_pass.apply(
[main_program], [startup_program], self._pass_context
)
# gradient_merge is a train-only optimization
if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
......
......@@ -48,8 +48,10 @@ def clear_all_process_groups():
_g_process_group_map[0] = ProcessGroup(0, [])
def new_process_group(ranks, group_id=None):
def new_process_group(ranks, group_id=None, force_new_group=False):
global _g_process_group_map
if not force_new_group:
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
for pg_id, pg in _g_process_group_map.items():
......@@ -137,7 +139,6 @@ class ProcessGroup:
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy, place).init_with_ring_id(
......
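`force_new_group` is what lets the sharding pass below create one dedicated communicator per communication stream: without it, the ranks-based dedup would keep returning the same group. A sketch of that usage, mirroring the stream-pairing loops later in this diff (`ranks` and the stream count are placeholders):

from paddle.distributed.auto_parallel.process_group import new_process_group

ranks = [0, 1, 2, 3]          # placeholder sharding group ranks
grad_comm_stream_num = 2      # placeholder stream count

pairs = []
for i in range(grad_comm_stream_num):
    # the first stream can reuse the existing sharding group; every extra
    # stream needs its own communicator, hence force_new_group=True
    group = new_process_group(ranks, force_new_group=(i > 0))
    pairs.append(
        {"comm_group": group, "comm_stream": "sharding_grad_comm_stream{}".format(i)}
    )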
......@@ -1184,6 +1184,8 @@ def _get_split_indices(
def set_grad_var_shape(program, dist_context):
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from .operators.common import infer_shape
block = program.global_block()
......@@ -1955,6 +1957,9 @@ def set_recompute_segments(model, losses, strategy, program):
and hasattr(model.gpt, "checkpoints")
):
ckpts = model.gpt.checkpoints
# the last recompute segment does not need to be recomputed
if len(ckpts) > 2:
ckpts.pop()
else:
ckpts = recompute.checkpoints
else:
......@@ -2189,6 +2194,7 @@ def insert_dependencies_for_two_ops(
dist_context,
is_recompute=False,
sync=False,
op_namescope=None,
):
"""
dependency: prior_op should be run before posterior_op
......@@ -2233,49 +2239,74 @@ def insert_dependencies_for_two_ops(
[block.var(name) for name in posterior_op.input_arg_names]
)
return insert_dependencies_for_two_vars(
return insert_dependencies_for_vars(
block,
idx,
first_var,
second_var,
dist_context,
OpRole.Backward,
prior_op_mesh,
is_recompute,
sync,
process_mesh=prior_op_mesh,
is_recompute=is_recompute,
sync=sync,
op_namescope=op_namescope,
use_nop=False,
)
def insert_dependencies_for_two_vars(
def insert_dependencies_for_vars(
block,
idx,
prior_var,
post_var,
prior_vars,
post_vars,
dist_context,
oprole,
process_mesh=None,
is_recompute=False,
sync=False,
op_namescope=None,
use_nop=False,
):
"""
dependency: op that generates prior_var should be run before op that generates post_var
dependency: op that generates prior_vars should be run before op that generates post_vars
"""
if isinstance(prior_vars, Variable):
prior_vars = [prior_vars]
if isinstance(post_vars, Variable):
post_vars = [post_vars]
for prior_var in prior_vars:
assert block.has_var(prior_var.name)
for post_var in post_vars:
assert block.has_var(post_var.name)
if process_mesh is None:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
post_var
post_vars[0]
).process_mesh
assert process_mesh is not None
use_nop = True
if use_nop:
depend_op = block._insert_op_without_sync(
idx,
type='nop',
inputs={
"X": prior_var,
"X": prior_vars,
},
outputs={"Out": post_var},
outputs={"Out": post_vars},
)
else:
depend_op = block._insert_op_without_sync(
idx,
type='depend',
inputs={
"X": post_vars,
"Dep": prior_vars,
},
outputs={"Out": post_vars},
)
# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
......@@ -2284,6 +2315,8 @@ def insert_dependencies_for_two_vars(
naive_set_dist_op_attr_for_program_by_mesh(
depend_op, process_mesh, dist_context, is_recompute
)
if op_namescope is not None:
depend_op._set_attr('op_namescope', "/{}".format(op_namescope))
if sync:
block._sync_with_cpp()
......@@ -2291,6 +2324,13 @@ def insert_dependencies_for_two_vars(
return depend_op
def is_dep_skip_op(op):
if "c_" in op.type:
return True
return False
def use_standalone_executor():
return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [
1,
......
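For orientation, a minimal sketch of how the passes in this PR call the renamed helper; `block`, `idx`, the variables and the namescope are placeholders, and the real call sites appear in the sharding, grad-clip and supplement-dependencies hunks below:

from paddle.distributed.auto_parallel.utils import (
    OpRole,
    insert_dependencies_for_vars,
)

def add_explicit_dep(block, idx, prior_vars, post_vars, dist_context):
    # encode "the producers of prior_vars must run before the producers/consumers
    # of post_vars" by inserting a nop/depend op at position idx
    return insert_dependencies_for_vars(
        block,
        idx,
        prior_vars,                  # Variable or list of Variables
        post_vars,                   # Variable or list of Variables
        dist_context,
        OpRole.Backward,
        process_mesh=[-1],           # hack used by these passes to skip dist-attr init
        is_recompute=False,
        sync=False,
        op_namescope="example_dep",  # hypothetical namescope, for debugging only
    )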
......@@ -23,6 +23,7 @@ from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .cpp_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403
......
......@@ -22,7 +22,7 @@ from paddle.distributed.auto_parallel.operators.common import (
from paddle.distributed.auto_parallel.utils import (
find_higher_order_backward_op,
get_var_numel,
insert_dependencies_for_two_vars,
insert_dependencies_for_vars,
is_forward_op,
is_loss_grad_op,
is_optimize_op,
......@@ -153,12 +153,12 @@ class DataParallelOptimizationPass(PassBase):
continue
assert op.has_attr(
"ring_id"
), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
), "Unexpected: comm op [{}] has NOT ring id.".format(str(op))
group = ring_id_to_process_group(op.attr("ring_id"))
assert (
group is not None
), "Unexception: data parallel group of [{}] from op [{}] is None".format(
), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op)
)
......@@ -187,7 +187,7 @@ class DataParallelOptimizationPass(PassBase):
not_synchronized_grads.append(grad_name)
assert (
len(not_synchronized_grads) == 0
), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
), "Unexpected: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads
)
......@@ -251,12 +251,12 @@ class DataParallelOptimizationPass(PassBase):
):
assert op.has_attr(
'rescale_grad'
), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format(
str(op)
)
assert (
len(op.input("Grad")) == 1
), "Unexception: op [{}] is supported to have only one input grad var.".format(
), "Unexpected: op [{}] is supported to have only one input grad var.".format(
str(op)
)
......@@ -271,7 +271,7 @@ class DataParallelOptimizationPass(PassBase):
assert scaled_grads == set(
self._grad_name_to_group_map.keys()
), "Unexception: gradients [{}] are unscaled.".format(
), "Unexpected: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads
)
......@@ -463,7 +463,7 @@ class DataParallelOptimizationPass(PassBase):
group.coalesce_var = group.gradients[0]
continue
# create coalecse tensor
# create coalesce tensor
group.coalesce_var = block.create_var(
name=unique_name.generate(
self.coalesce_prefix + '_{}'.format(i)
......@@ -508,12 +508,10 @@ class DataParallelOptimizationPass(PassBase):
for idx in sorted(remove_op_indices, reverse=True):
assert (
block.ops[idx].type in remove_op_types
), "Unexception: try to remove op {}".format(
str(block.ops[idx])
)
), "Unexpected: try to remove op {}".format(str(block.ops[idx]))
block._remove_op(idx, False)
# insert coalecse op
# insert coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
......@@ -596,7 +594,7 @@ class DataParallelOptimizationPass(PassBase):
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexception: {} has NOT been add prior Dep before allreduce.".format(
), "Unexpected: {} has NOT been add prior Dep before allreduce.".format(
not_sync_coalesces
)
......@@ -628,7 +626,7 @@ class DataParallelOptimizationPass(PassBase):
assert (
len(not_sync_coalesces) == 0
), "Unexception: {} has NOT been add post Dep after allreduce.".format(
), "Unexpected: {} has NOT been add post Dep after allreduce.".format(
not_sync_coalesces
)
......@@ -642,7 +640,7 @@ class DataParallelOptimizationPass(PassBase):
for idx, prior_name, post_name in dep_var_pairs:
prior_var = block.var(prior_name)
post_var = block.var(post_name)
depend_op = insert_dependencies_for_two_vars(
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_var,
......@@ -651,9 +649,10 @@ class DataParallelOptimizationPass(PassBase):
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesc var
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
block._sync_with_cpp()
......@@ -694,16 +693,17 @@ class DataParallelOptimizationPass(PassBase):
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info("Data Parallel Optimization: ")
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
" {} Allreduce ops are fused into {} coalesce allreduce ops.".format(
len(self._grad_name_to_group_map.keys()), len(grad_groups)
)
)
self._logger.info("gradient fusing group are following: ")
self._logger.debug("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
self._logger.debug(
"coalesce gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]
)
)
......@@ -711,12 +711,14 @@ class DataParallelOptimizationPass(PassBase):
individual_grads = set(self._grad_name_to_group_map.keys()) - set(
fused_grads
)
self._logger.info(
self._logger.debug(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)
)
)
self._logger.info("individual gradient {}".format(individual_grads))
self._logger.debug(
"individual gradient {}".format(individual_grads)
)
class GradientsGroup:
......
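Both this pass and the sharding pass below realize the overlap by tagging communication ops with a named execution stream and a scheduling priority on their dist_attr; the standalone executor then materializes the streams and schedules accordingly. A schematic helper capturing just that tagging pattern (names are placeholders):

def tag_comm_op_for_overlap(comm_op, stream_name, priority=-1):
    # stream_name is only a label; the standalone executor creates the actual
    # stream behind it. A negative priority lets comm ops be scheduled early
    # (see the NOTE referencing PR#49275 in the sharding pass).
    comm_op.dist_attr.execution_stream = stream_name
    comm_op.dist_attr.scheduling_priority = priority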
......@@ -23,11 +23,12 @@ from ..auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from ..auto_parallel.operators.common import SyncMode
from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import (
_get_comm_group,
insert_dependencies_for_two_vars,
insert_dependencies_for_vars,
is_gradient_clip_op,
is_optimize_op,
use_standalone_executor,
......@@ -372,8 +373,9 @@ class ClipGradByGloblNormPass(PassBase):
OP_ROLE_KEY: OpRole.Optimize,
},
)
# TODO better regulate the usage of op namescope
allreduce_op._set_attr(
'op_namescope', "/gradient_clip_pass"
'op_namescope', str('/') + SyncMode.GlobalNormSync
)
self.clip_helper._init_dist_attr(allreduce_op)
......@@ -394,15 +396,14 @@ class ClipGradByGloblNormPass(PassBase):
prior_op = block.ops[j]
break
j -= 1
print("here: ", block.ops[j])
assert (
prior_op is not None
), "Unexception: ClipByGlobalNorm could not find priory depend op"
), "Unexpected: ClipByGlobalNorm could not find priory depend op"
prior_var = block.vars[prior_op.output_arg_names[0]]
assert (
prior_var is not None
), "Unexception: ClipByGlobalNorm could not find priory depend var"
insert_dependencies_for_two_vars(
), "Unexpected: ClipByGlobalNorm could not find priory depend var"
insert_dependencies_for_vars(
block,
idx,
prior_var,
......@@ -414,6 +415,7 @@ class ClipGradByGloblNormPass(PassBase):
], # hack to avoid initializing the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="grad_clip_fill_constant_dep",
)
for varname in removed_tmp_var:
......
......@@ -474,6 +474,7 @@ class RecomputePass(PassBase):
self._dist_context,
is_recompute=True,
sync=False,
op_namescope="recompute_segment_dep",
)
main_program._sync_with_cpp()
......
......@@ -17,6 +17,7 @@ from functools import reduce
import paddle
from paddle.distributed.auto_parallel.operators.common import (
ParallelMode,
is_data_parallel_reduce_op,
is_parameter_related,
)
......@@ -25,12 +26,14 @@ from paddle.distributed.auto_parallel.utils import (
_get_comm_group,
get_logger,
get_var_numel,
insert_dependencies_for_vars,
is_backward_op,
is_dep_skip_op,
is_loss_grad_op,
is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import (
is_backward_op,
is_optimizer_op,
use_standalone_executor,
)
from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size
from paddle.fluid import unique_name
......@@ -85,15 +88,19 @@ class ShardingPass(PassBase):
self.set_attr("stage", None)
self.set_attr("sharding_degree", None) # for parallelizer
self.set_attr("degree", None) # for parallelizer_v2
self.set_attr("overlap_grad_comm", None)
self.set_attr("bucket_size_numel", None)
self.set_attr("enable_overlap", None)
self.set_attr("param_comm_stream_num", None)
self.set_attr("grad_comm_stream_num", None)
self.set_attr("param_bucket_size_numel", None)
self.set_attr("grad_bucket_size_numel", None)
self.set_attr("partition_algor", None)
self.set_attr("enable_hierarchical_comm", None)
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
self.sharding_infos = []
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.sharding_hybrid_dp = False
self.outer_dp_group = None
self.shared_params_grads = []
......@@ -121,13 +128,20 @@ class ShardingPass(PassBase):
"global_rank"
) < 0:
return False
if self.get_attr("overlap_grad_comm") is None:
if self.get_attr("enable_overlap") is None:
return False
if self.get_attr("param_comm_stream_num") is None:
return False
if self.get_attr("grad_comm_stream_num") is None:
return False
if self.get_attr("bucket_size_numel") is None:
if self.get_attr("param_bucket_size_numel") is None:
return False
if self.get_attr("grad_bucket_size_numel") is None:
return False
if self.get_attr("partition_algor") is None:
return False
if self.get_attr("enable_hierarchical_comm") is None:
return False
return True
def _check_conflict(self, other_pass):
......@@ -140,9 +154,24 @@ class ShardingPass(PassBase):
)
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
self.overlap_grad_comm = self.get_attr("overlap_grad_comm")
self.bucket_size_numel = int(self.get_attr("bucket_size_numel"))
self.enable_overlap = self.get_attr("enable_overlap")
self.param_comm_stream_num = int(self.get_attr("param_comm_stream_num"))
self.grad_comm_stream_num = int(self.get_attr("grad_comm_stream_num"))
self.enable_hierarchical_comm = self.get_attr(
"enable_hierarchical_comm"
)
if self.param_comm_stream_num > 1 or self.grad_comm_stream_num > 1:
assert (
self.enable_overlap
), "multiple comm stream need enable_overlap to be True"
self.param_bucket_size_numel = int(
self.get_attr("param_bucket_size_numel")
)
self.grad_bucket_size_numel = int(
self.get_attr("grad_bucket_size_numel")
)
self.partition_algor = self.get_attr("partition_algor")
params_grads = self.get_attr("params_grads")
main_block, startup_block = (
main_program.global_block(),
......@@ -226,7 +255,9 @@ class ShardingPass(PassBase):
# sharding hybrid data parallel: partial sharding param within
if dp_group.nranks > self.sharding_world_size:
self.partial_sharding = True
self.sharding_hybrid_dp = True
assert self.param_comm_stream_num < 2
assert self.grad_comm_stream_num < 2
assert (
len(self.dp_groups) == 1
), "hybrid sharding and data parallelism are supported only when there is excatly one data parallel group in the network"
......@@ -402,7 +433,7 @@ class ShardingPass(PassBase):
should_removed_optimizer_states = []
for idx, op in reversed(list(enumerate(main_block.ops))):
if not is_optimizer_op(op):
if not is_optimize_op(op):
break
if op.type in _supported_optimizer_type:
......@@ -441,7 +472,7 @@ class ShardingPass(PassBase):
def _insert_optimizer_broadcasts(self, main_block, startup_block):
if self.stage > 2 or self.bucket_size_numel > 1:
if self.stage > 2 or self.param_bucket_size_numel > 1:
return
for sharding_info in self.sharding_infos:
......@@ -460,6 +491,9 @@ class ShardingPass(PassBase):
OP_ROLE_KEY: OpRole.Optimize,
},
)
new_op._set_attr(
'op_namescope', str('/') + ParallelMode.DataParallel
)
param_dist_attr = (
self._dist_context.get_tensor_dist_attr_for_program(param)
)
......@@ -495,7 +529,7 @@ class ShardingPass(PassBase):
input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name]
_insert_reduce_op(
reduce_op = _insert_reduce_op(
main_block,
idx,
input_name,
......@@ -504,12 +538,15 @@ class ShardingPass(PassBase):
self._dist_context,
)
if (
not self.partial_sharding
not self.sharding_hybrid_dp
or not sharding_info.is_in_local_shard(base_name)
):
main_block._remove_op(idx + 1, sync=False)
else:
op._set_attr("ring_id", self.outer_dp_group.id)
op._set_attr(
'op_namescope', str('/') + ParallelMode.DataParallel
)
# NOTE:
# var@GRAD = sum(var@GRAD@RENAME@0, var@GRAD@RENAME@1)
......@@ -545,7 +582,7 @@ class ShardingPass(PassBase):
not_used_param_nane.append(param_name)
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_optimizer_op(op):
if is_optimize_op(op):
continue
for input_name in op.input_arg_names:
......@@ -643,21 +680,718 @@ class ShardingPass(PassBase):
def _optimization_pass(self, main_program, startup_program):
with paddle.static.program_guard(main_program, startup_program):
if self.overlap_grad_comm:
_fuse_overlap_gradient_comm()
if self.stage <= 1:
return
self.grad_coalesce_prefix = 'sharding_coalesce_grad_'
self.param_coalesce_prefix = 'sharding_coalesce_param_'
# NOTE see PR#49275 for details
self.comm_op_scheduling_priority = -1
# TODO support multiple sub_blocks
if self.bucket_size_numel > 1:
assert (
len(self.sharding_infos) == 1
), "gradient synchronization optimization only support one sharding group right now, but got [{}].".format(
len(self.sharding_infos)
)
sharding_info = self.sharding_infos[0]
with paddle.static.program_guard(main_program, startup_program):
self._gradient_sync_optimization(sharding_info)
# TODO decouple the fuse and overlap logic
# and support overlap even when there is no fuse
if self.param_bucket_size_numel > 1:
if self.stage == 2:
_fuse_overlap_parameter_comm_stage_two(
self.sharding_infos,
self._fuse_overlap_parameter_comm_stage_two(sharding_info)
elif self.stage == 3:
self._fuse_overlap_parameter_comm_stage_three(sharding_info)
def _gradient_sync_optimization(self, sharding_info):
if self.grad_bucket_size_numel <= 1 and (not self.enable_overlap):
return
main_block = default_main_program().global_block()
startup_block = default_startup_program().global_block()
coalesce_to_group_map, grad_name_to_group_map = self._group_grads(
main_block,
sharding_info,
)
self._overlap_grad_comm(
main_block,
sharding_info,
coalesce_to_group_map,
grad_name_to_group_map,
)
def _fuse_overlap_parameter_comm_stage_two(self, sharding_info):
main_block = default_main_program().global_block()
startup_block = default_startup_program().global_block()
group_to_param_map, param_to_group_map = group_param(
sharding_info, self.param_bucket_size_numel
)
_logger.info("Sharding Stage2 Optimization:")
_logger.info(
"Param Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format(
self.param_bucket_size_numel,
len(param_to_group_map.keys()),
len(group_to_param_map.keys()),
)
)
broadcast_var_to_group_map = {}
if self.enable_overlap:
# if the communication is cross-node, comm will be slow and calc will therefore
# wait for comm. enable multiple comm streams
# TODO revise me in the future
# 1. manage the comm groups and their corresponding streams
# 2. allow more than two streams and make the number configurable
self.param_comm_group_stream_pairs = []
ranks = sharding_info.group.ranks
for i in range(self.param_comm_stream_num):
if i == 0:
group = sharding_info.group
else:
group = new_process_group(ranks, force_new_group=True)
# NOTE here a stream is just a name; it is up to the executor
# to create the actual streams given the names.
stream = "sharding_param_comm_stream{}".format(i)
self.param_comm_group_stream_pairs.append(
{
"comm_group": group,
"comm_stream": stream,
}
)
_logger.info(
"Parameter Communication would use [{}] streams.".format(
self.param_comm_stream_num
)
)
self.op_to_stream_idx = {}
for i, param_group in enumerate(group_to_param_map.keys()):
assert len(param_group) >= 1
if len(param_group) > 1:
coalesce_var_name = unique_name.generate(
self.param_coalesce_prefix + str(i)
)
startup_block.create_var(
name=coalesce_var_name,
dtype=param_group.dtype,
persistable=True,
stop_gradient=True,
)
param_group.coalesce_var = main_block.create_var(
name=coalesce_var_name,
dtype=param_group.dtype,
persistable=True,
stop_gradient=True,
)
startup_block.append_op(
type="coalesce_tensor",
inputs={"Input": param_group.vars},
outputs={
"Output": param_group.vars,
"FusedOutput": param_group.coalesce_var,
},
attrs={
"copy_data": True,
"use_align": True,
"dtype": param_group.dtype,
OP_ROLE_KEY: OpRole.Forward,
},
)
else:
param_group.coalesce_var = param_group.vars[0]
_logger.info(
"Bucket[{}] size [{}]MB.".format(
i,
sum([get_var_size(p) for p in param_group.vars]),
)
)
_logger.debug(
"Bucket[{}] parameters: {}.".format(
i,
[p.name for p in param_group.vars],
)
)
broadcast_var_to_group_map[
param_group.coalesce_var.name
] = param_group
# TODO revise me to manage streams and comm groups
comm_stream_idx = i % self.param_comm_stream_num
comm_group = self.param_comm_group_stream_pairs[comm_stream_idx][
'comm_group'
]
comm_stream = self.param_comm_group_stream_pairs[comm_stream_idx][
'comm_stream'
]
new_op = main_block.append_op(
type='c_broadcast',
inputs={'X': param_group.coalesce_var},
outputs={'Out': param_group.coalesce_var},
attrs={
'ring_id': comm_group.id,
'root': param_group.rank,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
},
)
self.op_to_stream_idx[new_op] = comm_stream_idx
new_op._set_attr(
'op_namescope', str('/') + ParallelMode.DataParallel
)
if self.enable_overlap:
new_op.dist_attr.execution_stream = comm_stream
new_op.dist_attr.scheduling_priority = (
self.comm_op_scheduling_priority
)
# NOTE the current dist context lacks a representation for a bucket tensor that
# is composed of many tensors with different dims_mappings. we DO NOT assign a dist attr
# to it currently.
# add dependencies:
# 1. each broadcast depends on its preceding collective
# 2. for coalesced broadcasts, add nop ops to resolve data-flow dependencies
dep_map = {}
for i, op in enumerate(main_block.ops):
if is_sharding_param_broadcast_op(op):
broadcast_varname = op.output("Out")[0]
broadcast_var = main_block.vars[broadcast_varname]
param_group = broadcast_var_to_group_map[broadcast_varname]
comm_stream = None
if self.enable_overlap:
comm_stream = op.dist_attr.execution_stream
# FIXME remove me when upgrading to the multi-comm version
if len(dep_map.keys()) < self.param_comm_stream_num:
op = _get_broadcast_first_depend_op(main_block)
prior_var = main_block.vars[op.output("ParamOut")[0]]
else:
pre_op = main_block.ops[i - self.param_comm_stream_num]
assert is_sharding_param_broadcast_op(
pre_op
), "Unexpected: sharding broadcast pre op should be broadcast."
prior_var = main_block.vars[pre_op.output("Out")[0]]
# broadcast order dependencies
dep_map[i] = [(i, [prior_var], [broadcast_var], comm_stream)]
if len(param_group.vars) > 1:
# for in-shard groups, the coalesced broadcast depends on the optimizer update
if param_group.is_in_local_shard:
last_grad = param_group.vars[-1]
dep_map[i].append(
(i, [last_grad], [broadcast_var], comm_stream)
)
# coalesce resolution post deps
dep_map[i].append(
(i + 1, [broadcast_var], param_group.vars, comm_stream)
)
# insert deps
indices = sorted(list(dep_map.keys()), reverse=True)
for i in indices:
for idx, prior_vars, post_vars, comm_stream in dep_map[i][::-1]:
depend_op = insert_dependencies_for_vars(
main_block,
idx,
prior_vars,
post_vars,
self._dist_context,
fuse_size=self.bucket_size_numel,
OpRole.Optimize,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="sharding_stage2_broadcast_dep",
)
elif self.stage == 3:
_fuse_overlap_parameter_comm_stage_three(
self.sharding_infos, fuse_size=self.bucket_size_numel
if self.enable_overlap:
depend_op.dist_attr.execution_stream = comm_stream
depend_op.dist_attr.scheduling_priority = (
self.comm_op_scheduling_priority
)
main_block._sync_with_cpp()
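The bucket-to-stream assignment above (`comm_stream_idx = i % self.param_comm_stream_num`) is a plain round-robin; a tiny worked example with the stream count used by the unit test below and a made-up bucket count:

param_comm_stream_num = 2     # value used by the unit test in this PR
num_param_buckets = 5         # made-up for illustration

for i in range(num_param_buckets):
    comm_stream_idx = i % param_comm_stream_num
    print("bucket {} -> sharding_param_comm_stream{}".format(i, comm_stream_idx))
# buckets 0, 2, 4 use stream 0 (the original sharding group's communicator),
# buckets 1, 3 use stream 1 (a force_new_group communicator), so broadcasts
# on different streams can proceed concurrently.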
def _fuse_overlap_parameter_comm_stage_three(self, sharding_info):
pass
def _group_grads(
self,
block,
sharding_info,
):
"""
conditions for gradients to be grouped:
1. group size < grad_bucket_size_numel
2. same dp group (TODO)
3. same src rank
4. same dtype
5. dependency: grad would NOT be used by other ops within group segment
main logic:
1. record coalesce group
2. record all dp allreduce/reduce op idx
3. insert coalesce op
4. insert coalesce dependency (avoid allocate memory too early)
5. modify and remove allreduce/reduce op
6. ensure sharding-dp hybrid parallel logic
gradients inside the same group will be fused into one coalesce tensor
"""
ops = block.ops
if self.grad_bucket_size_numel < 1:
# numel for transformer layer
# h = 4096 + 1
# ffn_numel = 2 * (4 * h) * h
# mha_numel = 3 * h * h + h * h
# max_fuse_numel = ffn_numel + mha_numel
self.grad_bucket_size_numel = 1
first_backward_op = None
for op in ops:
if is_loss_grad_op(op):
first_backward_op = op
# no backward op found: sharding for inference
if first_backward_op is None:
return
first_backward_varname = first_backward_op.output_arg_names[0]
cur_group = VarGroup(self.grad_bucket_size_numel)
grad_groups = []
grouped_grad_names = set()
def op_depend_on_group(op, group):
vars_ = set(op.input_arg_names + op.output_arg_names)
var_names = set([var.name for var in group.vars])
return len(vars_.intersection(var_names)) > 0
# analyze groups
i = 0
while i < len(ops):
op = ops[i]
if is_data_parallel_reduce_op(op):
assert (
op.type == "c_reduce_sum"
), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
grad_name = op.output_arg_names[0]
param_name = _get_base_name_from_grad_name(grad_name)
rank = sharding_info.get_var_rank(param_name)
grad_var = block.var(grad_name)
if cur_group.acceptable(grad_var, rank):
assert grad_name not in grouped_grad_names
cur_group.collect(grad_var, rank)
else:
grad_groups.append(cur_group)
cur_group = VarGroup(self.grad_bucket_size_numel)
cur_group.collect(grad_var, rank)
if len(cur_group.vars) == 1:
cur_group.coalesce_op_idx = i - 1
# NOTE coalesce dependency: controls when memory is allocated for gradients.
# too early would increase the peak memory requirement; too late would hurt performance
j = 2
while is_dep_skip_op(ops[i - j]):
j += 1
dep_op = ops[i - j]
dep_varname = dep_op.output_arg_names[0]
cur_group.coalesce_dep_varname = dep_varname
grouped_grad_names.add(grad_name)
cur_group.reduce_op_indices.append(i)
if self.sharding_hybrid_dp and sharding_info.is_in_local_shard(
param_name
):
cur_group.is_in_local_shard = True
assert (
ops[i + 1].type == "c_allreduce_sum"
), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
assert (
ops[i + 1].output_arg_names[0] == grad_name
), "Hybrid Sharding with Data-Parallel should sync same gradient var"
cur_group.allreduce_op_indices.append(i + 1)
i += 1
elif op_depend_on_group(op, cur_group):
grad_groups.append(cur_group)
cur_group = VarGroup(self.grad_bucket_size_numel)
i += 1
# some grads not belonging to this rank may not be used after the dp reduce
if len(cur_group.vars) >= 1:
grad_groups.append(cur_group)
_logger.info("Sharding Gradient Communication Optimization:")
_logger.info(
"Gradient Bucket size is [{}], [{}] Gradients are fused into [{}] Buckets.".format(
self.grad_bucket_size_numel,
len(grouped_grad_names),
len(grad_groups),
)
)
# create coalesce tensors and record op indices
grad_name_to_group_map = {}
coalesce_to_group_map = {}
modify_reduce_op_map = {}
coalesce_op_map = {}
remove_reduce_op_indices = []
for i, group in enumerate(grad_groups):
if len(group.vars) > 1:
group.coalesce_var = block.create_var(
name=unique_name.generate(
self.grad_coalesce_prefix + str(i)
),
dtype=group.dtype,
persistable=False,
stop_gradient=True,
)
coalesce_op_map[group.coalesce_op_idx] = group
last_reduce_op_idx = group.reduce_op_indices.pop()
modify_reduce_op_map[last_reduce_op_idx] = group
remove_reduce_op_indices.extend(group.reduce_op_indices)
if group.is_in_local_shard:
last_allreduce_op_idx = group.allreduce_op_indices.pop()
modify_reduce_op_map[last_allreduce_op_idx] = group
remove_reduce_op_indices.extend(group.allreduce_op_indices)
else:
group.coalesce_var = group.vars[0]
for grad in group.vars:
grad_name_to_group_map[grad.name] = group
coalesce_to_group_map[group.coalesce_var.name] = group
coalesce_op_set = set(coalesce_op_map.keys())
modify_op_set = set(modify_reduce_op_map.keys())
remove_op_set = set(remove_reduce_op_indices)
conflict = coalesce_op_set.intersection(modify_op_set)
assert len(conflict) == 0
conflict = coalesce_op_set.intersection(remove_op_set)
assert len(conflict) == 0
conflict = modify_op_set.intersection(remove_op_set)
assert len(conflict) == 0
# update block
for idx, op in reversed(list(enumerate(block.ops))):
if idx in modify_reduce_op_map:
group = modify_reduce_op_map[idx]
grad_name = op.output_arg_names[0]
assert (
grad_name == group.vars[-1].name
), "Unexpected: it is supposed to sync [{}] but got [{}]".format(
group.vars[-1].name, grad_name
)
op._rename_input(grad_name, group.coalesce_var.name)
op._rename_output(grad_name, group.coalesce_var.name)
if idx in remove_reduce_op_indices:
block._remove_op(idx, sync=False)
if idx in coalesce_op_map:
group = coalesce_op_map[idx]
first_grad_name = group.vars[0].name
assert (
first_grad_name in op.output_arg_names
), "Unexpected: op is supposed to generate grad [{}] but got [{}]".format(
first_grad_name, str(op)
)
grad_names = [grad.name for grad in group.vars]
concated_shapes = []
concated_ranks = []
for grad_ in group.vars:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
coalesce_op = block._insert_op_without_sync(
idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var,
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes": concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward,
},
)
depend_op = insert_dependencies_for_vars(
block,
idx,
block.var(group.coalesce_dep_varname),
group.coalesce_var,
self._dist_context,
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="sharding_grad_coalesce_dep",
)
block._sync_with_cpp()
return coalesce_to_group_map, grad_name_to_group_map
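Stripped of the dependency bookkeeping, the grouping rule implemented above is "append to the current bucket until the dtype or source rank changes or the numel budget is exceeded". A standalone sketch of that policy with made-up gradients (the real VarGroup additionally records coalesce/reduce op indices and dependency vars):

def group_by_bucket(grads, bucket_size_numel):
    # grads: list of (name, numel, dtype, src_rank) tuples, made-up stand-ins for grad vars
    groups, cur, cur_numel, cur_key = [], [], 0, None
    for name, numel, dtype, rank in grads:
        key = (dtype, rank)
        if cur and (key != cur_key or cur_numel + numel > bucket_size_numel):
            groups.append(cur)
            cur, cur_numel = [], 0
        cur.append(name)
        cur_numel += numel
        cur_key = key
    if cur:
        groups.append(cur)
    return groups

grads = [("g0", 300, "fp16", 0), ("g1", 500, "fp16", 0),
         ("g2", 200, "fp16", 1), ("g3", 900, "fp32", 1)]
print(group_by_bucket(grads, bucket_size_numel=800))
# [['g0', 'g1'], ['g2'], ['g3']]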
def _overlap_grad_comm(
self,
block,
sharding_info,
coalesce_to_group_map,
grad_name_to_group_map,
):
"""
overlap gradient communication with backward & optimizer computation.
1. assign gradient communications to grad comm stream
2. for coalesce gradient communication:
2.1 insert before communication dependencies
2.2 insert after-communication dependencies only when needed
3. there is no need to add explicit dependencies for non-coalesce gradient communication
P.S. this overlap pass is ONLY adapted to the standalone executor (graph based) and the stream-aware allocator.
"""
if not use_standalone_executor() or (not self.enable_overlap):
return
self.grad_comm_group_stream_pairs = []
ranks = sharding_info.group.ranks
# NOTE since gradient synchronization involves computation, it would compete with
# the backward computation; therefore the number of streams used should be limited.
for i in range(self.grad_comm_stream_num):
if i == 0:
group = sharding_info.group
else:
group = new_process_group(ranks, force_new_group=True)
# NOTE here a stream is just a name; it is up to the executor
# to create the actual streams given the names.
stream = "sharding_grad_comm_stream{}".format(i)
self.grad_comm_group_stream_pairs.append(
{
"comm_group": group,
"comm_stream": stream,
}
)
ops = block.ops
# analyze dependencies
dep_map = {}
reduce_op_count = 0
grad_comm_op_to_stream_idx = {}
for idx, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
if op.type == "c_allreduce_sum":
continue
stream_idx = reduce_op_count % self.grad_comm_stream_num
grad_comm_op_to_stream_idx[op] = stream_idx
comm_group = self.grad_comm_group_stream_pairs[stream_idx][
"comm_group"
]
comm_stream = self.grad_comm_group_stream_pairs[stream_idx][
"comm_stream"
]
reduce_varname = op.output("Out")[0]
grad_group = coalesce_to_group_map[reduce_varname]
assert grad_group.coalesce_var.name == reduce_varname
# coalesce deps
if len(grad_group.vars) > 1:
# NOTE should the prior vars be all grads,
# given that the grad ops' order may be random?
# prior dep
dep_map[idx] = [
(
idx,
grad_group.vars[-1],
grad_group.coalesce_var,
comm_stream,
)
]
# post dep
post_idx = idx + 1
if self.sharding_hybrid_dp and grad_group.is_in_local_shard:
post_idx += 1
dep_map[idx].append(
(
post_idx,
grad_group.coalesce_var,
grad_group.vars,
comm_stream,
)
)
# assign stream
op.dist_attr.execution_stream = comm_stream
op.dist_attr.scheduling_priority = (
self.comm_op_scheduling_priority
)
op._set_attr("ring_id", comm_group.id)
if self.sharding_hybrid_dp and grad_group.is_in_local_shard:
next_op = ops[idx + 1]
assert next_op.type == "c_allreduce_sum"
assert next_op.output("Out")[0] == reduce_varname
# FIXME support multiple comm groups & streams for hybrid sharding-dp in the future
# next_op._set_attr("ring_id", comm_group.id)
next_op.dist_attr.execution_stream = comm_stream
next_op.dist_attr.scheduling_priority = (
self.comm_op_scheduling_priority
)
idx += 1
reduce_op_count += 1
idx += 1
# insert deps
indices = sorted(list(dep_map.keys()), reverse=True)
for i in indices:
for idx, prior_vars, post_vars, comm_stream in dep_map[i][::-1]:
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_vars,
post_vars,
self._dist_context,
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="sharding_grad_comm_dep",
)
depend_op.dist_attr.execution_stream = comm_stream
depend_op.dist_attr.scheduling_priority = (
self.comm_op_scheduling_priority
)
# hierarchical grad comm
if self.enable_hierarchical_comm:
# NOTE so far we only support isomorphic clusters with 8 ranks per node
# TODO unify communicator creation here
# create communicators
nranks_per_node = 8
assert self.sharding_world_size % nranks_per_node == 0
global_group = sharding_info.group
global_ranks = global_group.ranks
relative_idx_in_node = self.global_rank % nranks_per_node
node_idx = self.global_rank // nranks_per_node
inter_node_ranks = [
rank
for rank in global_ranks
if rank % nranks_per_node == relative_idx_in_node
]
_logger.info(
"Sharding Gradient Hierarchical Communication Optimization."
)
_logger.info(
"current global rank idx: {}.".format(self.global_rank)
)
_logger.info(
"local inter node ranks idx: {}.".format(inter_node_ranks)
)
assert (
len(inter_node_ranks)
== self.sharding_world_size // nranks_per_node
)
intra_node_ranks = [
rank
for rank in global_ranks
if rank // nranks_per_node == node_idx
]
assert len(intra_node_ranks) == nranks_per_node
_logger.info(
"local intra node ranks idx: {}.".format(intra_node_ranks)
)
inter_node_groups = []
intra_node_groups = []
for _ in range(self.grad_comm_stream_num):
# TODO reuse the original communicator
inter_node_groups.append(
new_process_group(inter_node_ranks, force_new_group=True)
)
intra_node_groups.append(
new_process_group(intra_node_ranks, force_new_group=True)
)
# update program
for idx, op in reversed(list(enumerate(block.ops))):
if is_data_parallel_reduce_op(op):
assert op.type == "c_reduce_sum"
grad_comm_stream_idx = grad_comm_op_to_stream_idx[op]
inter_node_group = inter_node_groups[grad_comm_stream_idx]
intra_node_group = intra_node_groups[grad_comm_stream_idx]
reduce_varname = op.output("Out")[0]
if self.enable_overlap:
comm_stream = op.dist_attr.execution_stream
dst_rank = int(op.attr("root_id"))
in_peer = False
if dst_rank % nranks_per_node == relative_idx_in_node:
in_peer = True
intra_node_dst = dst_rank % nranks_per_node
op._set_attr('ring_id', intra_node_group.id)
op._set_attr('root_id', intra_node_dst)
if in_peer:
inter_node_dst = dst_rank // nranks_per_node
new_op = block._insert_op_without_sync(
idx + 1,
type='c_reduce_sum',
inputs={"X": reduce_varname},
outputs={
"Out": reduce_varname,
},
attrs={
'ring_id': inter_node_group.id,
'root_id': inter_node_dst,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward,
},
)
new_op._set_attr(
'op_namescope', str('/') + ParallelMode.DataParallel
)
if self.enable_overlap:
new_op.dist_attr.execution_stream = comm_stream
new_op.dist_attr.scheduling_priority = (
self.comm_op_scheduling_priority
)
block._sync_with_cpp()
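To make the hierarchical split above concrete, here is the rank bookkeeping for a hypothetical 16-rank sharding group with 8 ranks per node, viewed from global rank 10:

nranks_per_node = 8
global_ranks = list(range(16))     # hypothetical 2-node sharding group
global_rank = 10                   # hypothetical current rank

relative_idx_in_node = global_rank % nranks_per_node    # 2
node_idx = global_rank // nranks_per_node               # 1

inter_node_ranks = [
    r for r in global_ranks if r % nranks_per_node == relative_idx_in_node
]  # [2, 10]
intra_node_ranks = [
    r for r in global_ranks if r // nranks_per_node == node_idx
]  # [8, 9, 10, 11, 12, 13, 14, 15]

# each gradient is first reduced inside the node (intra_node_ranks); only the
# rank whose local slot matches the destination's slot then reduces across
# nodes (inter_node_ranks), so cross-node traffic per rank shrinks by roughly
# a factor of nranks_per_node.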
def _get_broadcast_first_depend_op(block):
for op in block.ops:
if op.type in _supported_optimizer_type:
return op
raise Exception("Could not find optimizer op.")
def _insert_init_and_broadcast_op(
......@@ -690,6 +1424,7 @@ def _insert_init_and_broadcast_op(
OP_ROLE_KEY: op_role,
},
)
new_op._set_attr('op_namescope', str('/') + ParallelMode.DataParallel)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op,
broadcast_var_dist_attr.process_mesh,
......@@ -749,6 +1484,8 @@ def _insert_reduce_op(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, dist_attr.process_mesh, dist_attr.dims_mapping, dist_context
)
new_op._set_attr('op_namescope', str('/') + ParallelMode.DataParallel)
return new_op
def _get_dp_and_sharding_groups(origin_group, sharding_group_size, rank):
......@@ -790,7 +1527,7 @@ def _is_param_grad_fp32_cast_op(block, op):
def _is_param_fp16_cast_op(block, op, params):
if is_optimizer_op(op):
if is_optimize_op(op):
return False
if not _is_desired_cast_op(block, op):
return False
......@@ -862,6 +1599,14 @@ def _is_forward_op(op):
return op.attr("op_role") == 0
def is_sharding_param_broadcast_op(op):
return (
op.type == "c_broadcast"
and op.desc.has_attr("op_namescope")
and ParallelMode.DataParallel in op.desc.attr("op_namescope")
)
def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dp_group = None
......@@ -975,7 +1720,7 @@ def re_order_program(block, param_grads, dist_context):
num_ops = len(block.ops)
remove_op_indices = []
# TODO support case when optimizer is not the last op
if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type:
if is_optimize_op(last_op) and last_op.type in _supported_optimizer_type:
# record optimizer
for idx, op in reversed(list(enumerate(block.ops))):
if op.type not in _supported_optimizer_type:
......@@ -1018,16 +1763,20 @@ def group_param(sharding_info, fuse_size):
group_to_param_map = {}
param_to_group_map = {}
bucket = []
cur_group = ParameterGroup(fuse_size)
cur_group = VarGroup(fuse_size)
for param in sharding_info.params:
rank = sharding_info.get_var_rank(param.name)
if cur_group.acceptable(param, rank):
cur_group.collect(param, rank)
else:
cur_group = ParameterGroup(fuse_size)
cur_group = VarGroup(fuse_size)
cur_group.collect(param, rank)
cur_group.is_in_local_shard = sharding_info.is_in_local_shard(
param.name
)
if cur_group in group_to_param_map:
group_to_param_map[cur_group].append(param.name)
else:
......@@ -1038,106 +1787,6 @@ def group_param(sharding_info, fuse_size):
return group_to_param_map, param_to_group_map
def _fuse_overlap_gradient_comm():
pass
def _fuse_overlap_parameter_comm_stage_two(
sharding_infos, dist_context, fuse_size
):
assert (
len(sharding_infos) == 1
), "fuse overlap optimization only support one sharding group right now, but got [{}].".format(
len(sharding_infos)
)
sharding_info = sharding_infos[0]
main_block = default_main_program().global_block()
startup_block = default_startup_program().global_block()
group_to_param_map, param_to_group_map = group_param(
sharding_info, fuse_size
)
_logger.info("Sharding Stage2 Optimization:")
_logger.info(
"Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format(
fuse_size,
len(param_to_group_map.keys()),
len(group_to_param_map.keys()),
)
)
for i, group in enumerate(group_to_param_map.keys()):
assert len(group) >= 1
if len(group) > 1:
coalesce_var_name = unique_name.generate(
'coalecse_param_{}'.format(i)
)
startup_block.create_var(
name=coalesce_var_name,
dtype=group.dtype,
persistable=True,
stop_gradient=True,
)
group.coalesce_var = main_block.create_var(
name=coalesce_var_name,
dtype=group.dtype,
persistable=True,
stop_gradient=True,
)
startup_block.append_op(
type="coalesce_tensor",
inputs={"Input": group.params},
outputs={
"Output": group.params,
"FusedOutput": group.coalesce_var,
},
attrs={
"copy_data": True,
"use_align": True,
"dtype": group.dtype,
OP_ROLE_KEY: OpRole.Forward,
},
)
else:
group.coalesce_var = group.params[0]
_logger.info(
"Bucket[{}] size [{}]MB : {}".format(
i,
sum([get_var_size(p) for p in group.params]),
[p.name for p in group.params],
)
)
# TODO Overlap broadcast with opt and next forward
new_op = main_block.append_op(
type='c_broadcast',
inputs={'X': group.coalesce_var},
outputs={'Out': group.coalesce_var},
attrs={
'ring_id': sharding_info.group.id,
'root': group.rank,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
},
)
# NOTE the current dist context lack the presentation for bucket tensor which
# composes many tensor with different dims_mapping. we assign a fake dist attr
# for it currently.
def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size):
assert (
len(sharding_infos) == 1
), "fuse overlap optimization only support one sharding group right now, but got [{}].".format(
len(sharding_infos)
)
sharding_info = sharding_infos[0]
class ShardingInfo(object):
def __init__(self, group, rank, params_grads, partition_algor):
self.group = group
......@@ -1188,7 +1837,7 @@ class ShardingInfo(object):
param_usage = {x: 0 for x in self.param_names}
for op in block.ops:
if is_optimizer_op(op):
if is_optimize_op(op):
continue
for input_name in op.input_arg_names:
if input_name in self.param_names:
......@@ -1220,14 +1869,19 @@ class ShardingInfo(object):
return self.params_grads.get(param_name, None)
class ParameterGroup(object):
class VarGroup(object):
def __init__(self, max_size):
self.max_siez = max_size
self.dtype = None
self.rank = -1
self.numel = 0
self.params = []
self.vars = []
self.coalesce_var = None
self.coalesce_dep_varname = None
self.coalesce_op_idx = None
self.reduce_op_indices = []
self.allreduce_op_indices = []
self.is_in_local_shard = False
def acceptable(self, param, rank):
if self.numel == 0:
......@@ -1245,7 +1899,7 @@ class ParameterGroup(object):
self.dtype = param.dtype
self.rank = rank
self.numel += get_var_numel(param)
self.params.append(param)
self.vars.append(param)
def __len__(self):
return len(self.params)
return len(self.vars)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.auto_parallel.operators.common import (
is_amp_flag_sync_op,
is_data_parallel_reduce_op,
is_global_norm_sync_op,
)
from paddle.distributed.auto_parallel.utils import (
OpRole,
insert_dependencies_for_vars,
use_standalone_executor,
)
from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type
from .pass_base import PassBase, register_pass
def _sharding_pass_applied(pass_ctx):
for applied_pass in pass_ctx.passes:
if isinstance(applied_pass, ShardingPass):
return True
return False
# NOTE we add the "auto_parallel" prefix to the pass in order to
# indicate that this pass should obey some constraints imposed by auto_parallel,
# for example, all ops and vars should have dist attrs before and after the pass,
# and dist ops should be used instead of custom comm ops
@register_pass("auto_parallel_supplement_explicit_dependencies")
class AutoParalSupplementDepPass(PassBase):
"""
Functional concern:
for strategies like amp & global norm, there is a collective communication to sync gradient information across all ranks.
after partitioning the gradients to each rank, the order of that collective communication differs between ranks
and might cause hangs in the graph-based, random-order executor. here we supplement explicit dependencies for those cases.
TODO Performance concern:
a global collective introduces global synchronization, which forces fast workers to wait for slow ones.
therefore we should issue this collective only when all ranks reach the same stage.
BUT the depend API offered by the executor can only ensure "not before", not "right after".
Some ranks might call the collectives earlier than other ranks even though they still have local work that could be done while waiting for slow peers.
An IR pass currently cannot fully control when these global collectives are performed.
"""
def __init__(self):
super().__init__()
self.set_attr("dist_context", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
# TODO generalize this pass for all cases.
if not use_standalone_executor() or not _sharding_pass_applied(context):
return
self._dist_context = self.get_attr("dist_context", None)
self.flags_sync_stream = "flags_sync_stream"
main_block = main_program.global_block()
startup_block = startup_program.global_block()
# last dp grad communication
last_dp_reduce_op_idx = -1
last_dp_reduce_varname = None
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_data_parallel_reduce_op(op):
last_dp_reduce_op_idx = idx
last_dp_reduce_varname = op.output_arg_names[0]
break
assert last_dp_reduce_op_idx > 0
assert last_dp_reduce_varname is not None
# analyze deps for amp & global norm
deps_map = {}
prior_varname = last_dp_reduce_varname
for idx, op in enumerate(main_block.ops):
if is_amp_flag_sync_op(op) or is_global_norm_sync_op(op):
op_namescope = None
if is_amp_flag_sync_op(op):
op_namescope = "amp_flag_sync_dep"
op.dist_attr.execution_stream = self.flags_sync_stream
elif is_global_norm_sync_op(op):
op_namescope = "global_norm_sync_dep"
deps_map[idx] = (prior_varname, op.input("X")[0], op_namescope)
prior_varname = op.output("Out")[0]
# analyze deps for check_finite_and_unscale
# ensure it is performed after the last backward computation, thereby reducing the
# straggling of the amp-flag-sync
first_check_op = True
for idx, op in enumerate(main_block.ops):
if op.type == "check_finite_and_unscale":
if first_check_op:
last_backward_op = main_block.ops[idx - 1]
prior_varname = last_backward_op.output_arg_names[0]
first_check_op = False
deps_map[idx] = (
prior_varname,
op.input("Scale")[0],
"check_finite_dep",
)
# analyze deps for optimizer
# optimizers order should be fixed to allow broadcast to overlap with optimizer
first_optimizer_op = True
for idx, op in enumerate(main_block.ops):
if op.type in _supported_optimizer_type:
if first_optimizer_op:
first_optimizer_op = False
else:
deps_map[idx] = (
prior_varname,
op.input("Param")[0],
"optimizer_order_dep",
)
prior_varname = op.output("ParamOut")[0]
# insert deps
indices = sorted(list(deps_map.keys()), reverse=True)
for idx in indices:
prior_var = main_block.var(deps_map[idx][0])
post_var = main_block.var(deps_map[idx][1])
op_namescope = deps_map[idx][2]
depend_op = insert_dependencies_for_vars(
main_block,
idx,
prior_var,
post_var,
self._dist_context,
OpRole.Optimize,
process_mesh=[
-1
], # hack to avoid initializing the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope=op_namescope,
)
main_block._sync_with_cpp()
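The Parallelizer hunk near the top of this PR shows where this pass is wired in; the same call pattern, sketched in isolation (main_program, startup_program, dist_context and pass_context are assumed to be the objects the parallelizer already holds):

from paddle.distributed.passes import new_pass

# minimal sketch, assuming main_program / startup_program / dist_context /
# pass_context come from an earlier auto-parallel partition step
config = {"dist_context": dist_context}
dep_pass = new_pass("auto_parallel_supplement_explicit_dependencies", config)
dep_pass.apply([main_program], [startup_program], pass_context)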
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
paddle.enable_static()
def apply_pass(use_sharding=False, use_amp=False, use_recompute=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 2
sharding.enable_overlap = True
sharding.param_comm_stream_num = 2
sharding.grad_comm_stream_num = 2
sharding.param_bucket_size_numel = 512 * 512
sharding.grad_bucket_size_numel = 128 * 128
sharding.partition_algor = 'use_order'
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.custom_white_list = [
'lookup_table_v2',
'lookup_table',
'softmax',
'layer_norm',
'gelu',
]
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
amp.use_optimizer_fp16 = False
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(
self, use_sharding=False, use_amp=False, use_recompute=False
):
reset_prog()
strategy = apply_pass(use_sharding, use_amp, use_recompute)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_param_grad_fuse_overlap(self, program):
num_op = 0
num_coalesce = 0
num_reduce = 0
num_broadcast = 0
for op in program.global_block().ops:
if op.type == "nop" or op.type == "depend":
num_op += 1
elif op.type == "coalesce_tensor":
num_coalesce += 1
elif op.type == "c_reduce_sum":
num_reduce += 1
elif op.type == "c_broadcast":
num_broadcast += 1
if paddle.distributed.get_rank() == 0:
self.assertEqual(num_op, 22)
else:
self.assertEqual(num_op, 54)
self.assertEqual(num_coalesce, 5)
self.assertEqual(num_reduce, 14)
self.assertEqual(num_broadcast, 2)
def test_param_grad_fuse_overlap(self):
# dp2
dp_engine = self.get_engine()
dp_history = dp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
dp_loss = dp_history.history['loss'][0]
# sharding2
sharding_engine = self.get_engine(use_sharding=True)
sharding_history = sharding_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
sharding_loss = sharding_history.history['loss'][0]
# amp, recompute
amp_recompute_engine = self.get_engine(
use_sharding=False, use_amp=True, use_recompute=True
)
amp_recompute_history = amp_recompute_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
amp_recompute_loss = amp_recompute_history.history['loss'][0]
# sharding2, amp, recompute
all_engine = self.get_engine(
use_sharding=True, use_amp=True, use_recompute=True
)
all_history = all_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
all_loss = all_history.history['loss'][0]
self.check_param_grad_fuse_overlap(sharding_engine.main_program)
np.testing.assert_allclose(
dp_loss, sharding_loss, rtol=1e-05, atol=1e-08
)
np.testing.assert_allclose(
amp_recompute_loss, all_loss, rtol=1e-05, atol=1e-08
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
os.environ["FLAGS_CONVERT_GRAPH_TO_PROGRAM"] = str(1)
os.environ["FLAGS_add_dependency_for_communication_op"] = 'false'
class TestShardingWithNewEXE(unittest.TestCase):
def test_stage2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "sharding_newexe.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -52,9 +52,13 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.degree, 8)
self.assertAlmostEqual(sharding.overlap_grad_comm, False)
self.assertAlmostEqual(sharding.bucket_size_numel, -1)
self.assertAlmostEqual(sharding.enable_overlap, False)
self.assertAlmostEqual(sharding.param_comm_stream_num, 1)
self.assertAlmostEqual(sharding.grad_comm_stream_num, 1)
self.assertAlmostEqual(sharding.partition_algor, "greedy_even")
self.assertAlmostEqual(sharding.param_bucket_size_numel, 1)
self.assertAlmostEqual(sharding.grad_bucket_size_numel, 1)
self.assertAlmostEqual(sharding.enable_hierarchical_comm, False)
self.assertEqual(sharding.enable_tuning, False)
self.assertEqual(sharding.tuning_range, [])
......