未验证 提交 5592f8ad 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel-Performance] Sharding Comm Optimization (#48604)

* remove deps and prior comm

* grad comm fuse

* add deps for amp&global norm

* stage2 broadcast prior deps

* stage2 grad overlap

* stream_analyzer bugfix

* overlap enable

* dep op namescope

* depend support multiple inputs

* check finite deps

* stage2 param comm overlap

* Set kD2HStream

* grad comm hierarchical

* grad comm hierarchical

* new unitest
Co-authored-by: Nchenruibiao <chenruibiao@baidu.com>
上级 852c8db3
......@@ -90,8 +90,12 @@ SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "overlap_grad_comm", False)
set_field_default_config(SHARDING, "bucket_size_numel", -1)
set_field_default_config(SHARDING, "enable_overlap", False)
set_field_default_config(SHARDING, "param_comm_stream_num", 1)
set_field_default_config(SHARDING, "grad_comm_stream_num", 1)
set_field_default_config(SHARDING, "param_bucket_size_numel", 1)
set_field_default_config(SHARDING, "grad_bucket_size_numel", 1)
set_field_default_config(SHARDING, "enable_hierarchical_comm", False)
set_field_default_config(SHARDING, "partition_algor", "greedy_even")
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
......
......@@ -45,6 +45,15 @@ class ParallelMode:
MoEParallel = "auto_parallel/moe_parallel"
class SyncMode:
"""
the synchorization mode for communication or auxiliary operator
"""
AmpFlagSync = "auto_parallel/amp_flag_synchorization"
GlobalNormSync = "auto_parallel/global_norm_synchorization"
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
......@@ -441,7 +450,7 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
assert (
dims_mapping is not None
), "Unexception: dims_mapping of output [{}] of op [{}] is None".format(
), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format(
grad_var.name, op_dist_attr.op_type
)
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
......@@ -502,6 +511,22 @@ def is_data_parallel_reduce_op(op):
)
def is_amp_flag_sync_op(op):
return (
op.type == "c_allreduce_max"
and op.desc.has_attr("op_namescope")
and SyncMode.AmpFlagSync in op.desc.attr("op_namescope")
)
def is_global_norm_sync_op(op):
return (
op.type == "c_allreduce_sum"
and op.desc.has_attr("op_namescope")
and SyncMode.GlobalNormSync in op.desc.attr("op_namescope")
)
def is_in_backward_phase(dist_ctx):
# NOTE currently high-order differential in Paddle dose NOT distinguish gradient computation operators
# in Forward phase and operators in Backward phase (both with op_role=1), which will mislead
......
......@@ -24,6 +24,7 @@ from ..utils import set_dist_op_desc_original_id, set_var_dist_attr
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
SyncMode,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
......@@ -166,6 +167,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Optimize,
},
)
allreduce_op._set_attr('op_namescope', str('/') + SyncMode.AmpFlagSync)
cast_op2 = main_block.append_op(
type='cast',
inputs={'X': inf_var_int32},
......
......@@ -318,6 +318,16 @@ class Parallelizer:
[main_program], [startup_program], self._pass_context
)
# deps for newexe
config = {}
config["dist_context"] = self._dist_context
APSED_pass = new_pass(
"auto_parallel_supplement_explicit_dependencies", config
)
APSED_pass.apply(
[main_program], [startup_program], self._pass_context
)
# gradient_merge is then train-only optimization
if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
......
......@@ -48,8 +48,10 @@ def clear_all_process_groups():
_g_process_group_map[0] = ProcessGroup(0, [])
def new_process_group(ranks, group_id=None):
def new_process_group(ranks, group_id=None, force_new_group=False):
global _g_process_group_map
if not force_new_group:
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
for pg_id, pg in _g_process_group_map.items():
......@@ -137,7 +139,6 @@ class ProcessGroup:
]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy, place).init_with_ring_id(
......
......@@ -1184,6 +1184,8 @@ def _get_split_indices(
def set_grad_var_shape(program, dist_context):
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from .operators.common import infer_shape
block = program.global_block()
......@@ -1955,6 +1957,9 @@ def set_recompute_segments(model, losses, strategy, program):
and hasattr(model.gpt, "checkpoints")
):
ckpts = model.gpt.checkpoints
# last recompute segment is not need to recompute
if len(ckpts) > 2:
ckpts.pop()
else:
ckpts = recompute.checkpoints
else:
......@@ -2189,6 +2194,7 @@ def insert_dependencies_for_two_ops(
dist_context,
is_recompute=False,
sync=False,
op_namescope=None,
):
"""
dependency: prior_op should be run before posterior_op
......@@ -2233,49 +2239,74 @@ def insert_dependencies_for_two_ops(
[block.var(name) for name in posterior_op.input_arg_names]
)
return insert_dependencies_for_two_vars(
return insert_dependencies_for_vars(
block,
idx,
first_var,
second_var,
dist_context,
OpRole.Backward,
prior_op_mesh,
is_recompute,
sync,
process_mesh=prior_op_mesh,
is_recompute=is_recompute,
sync=sync,
op_namescope=op_namescope,
use_nop=False,
)
def insert_dependencies_for_two_vars(
def insert_dependencies_for_vars(
block,
idx,
prior_var,
post_var,
prior_vars,
post_vars,
dist_context,
oprole,
process_mesh=None,
is_recompute=False,
sync=False,
op_namescope=None,
use_nop=False,
):
"""
dependency: op that generates prior_var should be run before op that generates post_var
dependency: op that generates prior_vars should be run before op that generates post_vars
"""
if isinstance(prior_vars, Variable):
prior_vars = [prior_vars]
if isinstance(post_vars, Variable):
post_vars = [post_vars]
for prior_var in prior_vars:
assert block.has_var(prior_var.name)
for post_var in post_vars:
assert block.has_var(post_var.name)
if process_mesh is None:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
post_var
post_vars[0]
).process_mesh
assert process_mesh is not None
use_nop = True
if use_nop:
depend_op = block._insert_op_without_sync(
idx,
type='nop',
inputs={
"X": prior_var,
"X": prior_vars,
},
outputs={"Out": post_var},
outputs={"Out": post_vars},
)
else:
depend_op = block._insert_op_without_sync(
idx,
type='depend',
inputs={
"X": post_vars,
"Dep": prior_vars,
},
outputs={"Out": post_vars},
)
# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
......@@ -2284,6 +2315,8 @@ def insert_dependencies_for_two_vars(
naive_set_dist_op_attr_for_program_by_mesh(
depend_op, process_mesh, dist_context, is_recompute
)
if op_namescope is not None:
depend_op._set_attr('op_namescope', "/{}".format(op_namescope))
if sync:
block._sync_with_cpp()
......@@ -2291,6 +2324,13 @@ def insert_dependencies_for_two_vars(
return depend_op
def is_dep_skip_op(op):
if "c_" in op.type:
return True
return False
def use_standalone_executor():
return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [
1,
......
......@@ -23,6 +23,7 @@ from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .cpp_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403
......
......@@ -22,7 +22,7 @@ from paddle.distributed.auto_parallel.operators.common import (
from paddle.distributed.auto_parallel.utils import (
find_higher_order_backward_op,
get_var_numel,
insert_dependencies_for_two_vars,
insert_dependencies_for_vars,
is_forward_op,
is_loss_grad_op,
is_optimize_op,
......@@ -153,12 +153,12 @@ class DataParallelOptimizationPass(PassBase):
continue
assert op.has_attr(
"ring_id"
), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
), "Unexpected: comm op [{}] has NOT ring id.".format(str(op))
group = ring_id_to_process_group(op.attr("ring_id"))
assert (
group is not None
), "Unexception: data parallel group of [{}] from op [{}] is None".format(
), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op)
)
......@@ -187,7 +187,7 @@ class DataParallelOptimizationPass(PassBase):
not_synchronized_grads.append(grad_name)
assert (
len(not_synchronized_grads) == 0
), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
), "Unexpected: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads
)
......@@ -251,12 +251,12 @@ class DataParallelOptimizationPass(PassBase):
):
assert op.has_attr(
'rescale_grad'
), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
), "Unexpected: op [{}] is supported to have [rescale_grad] attribute.".format(
str(op)
)
assert (
len(op.input("Grad")) == 1
), "Unexception: op [{}] is supported to have only one input grad var.".format(
), "Unexpected: op [{}] is supported to have only one input grad var.".format(
str(op)
)
......@@ -271,7 +271,7 @@ class DataParallelOptimizationPass(PassBase):
assert scaled_grads == set(
self._grad_name_to_group_map.keys()
), "Unexception: gradients [{}] are unscaled.".format(
), "Unexpected: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads
)
......@@ -463,7 +463,7 @@ class DataParallelOptimizationPass(PassBase):
group.coalesce_var = group.gradients[0]
continue
# create coalecse tensor
# create coalesce tensor
group.coalesce_var = block.create_var(
name=unique_name.generate(
self.coalesce_prefix + '_{}'.format(i)
......@@ -508,12 +508,10 @@ class DataParallelOptimizationPass(PassBase):
for idx in sorted(remove_op_indices, reverse=True):
assert (
block.ops[idx].type in remove_op_types
), "Unexception: try to remove op {}".format(
str(block.ops[idx])
)
), "Unexpected: try to remove op {}".format(str(block.ops[idx]))
block._remove_op(idx, False)
# insert coalecse op
# insert coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
......@@ -596,7 +594,7 @@ class DataParallelOptimizationPass(PassBase):
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexception: {} has NOT been add prior Dep before allreduce.".format(
), "Unexpected: {} has NOT been add prior Dep before allreduce.".format(
not_sync_coalesces
)
......@@ -628,7 +626,7 @@ class DataParallelOptimizationPass(PassBase):
assert (
len(not_sync_coalesces) == 0
), "Unexception: {} has NOT been add post Dep after allreduce.".format(
), "Unexpected: {} has NOT been add post Dep after allreduce.".format(
not_sync_coalesces
)
......@@ -642,7 +640,7 @@ class DataParallelOptimizationPass(PassBase):
for idx, prior_name, post_name in dep_var_pairs:
prior_var = block.var(prior_name)
post_var = block.var(post_name)
depend_op = insert_dependencies_for_two_vars(
depend_op = insert_dependencies_for_vars(
block,
idx,
prior_var,
......@@ -651,9 +649,10 @@ class DataParallelOptimizationPass(PassBase):
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesc var
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
block._sync_with_cpp()
......@@ -694,16 +693,17 @@ class DataParallelOptimizationPass(PassBase):
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info("Data Parallel Optimization: ")
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
" {} Allreduce ops are fused into {} coalesce allreduce ops.".format(
len(self._grad_name_to_group_map.keys()), len(grad_groups)
)
)
self._logger.info("gradient fusing group are following: ")
self._logger.debug("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
self._logger.debug(
"coalesce gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]
)
)
......@@ -711,12 +711,14 @@ class DataParallelOptimizationPass(PassBase):
individual_grads = set(self._grad_name_to_group_map.keys()) - set(
fused_grads
)
self._logger.info(
self._logger.debug(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)
)
)
self._logger.info("individual gradient {}".format(individual_grads))
self._logger.debug(
"individual gradient {}".format(individual_grads)
)
class GradientsGroup:
......
......@@ -23,11 +23,12 @@ from ..auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from ..auto_parallel.operators.common import SyncMode
from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import (
_get_comm_group,
insert_dependencies_for_two_vars,
insert_dependencies_for_vars,
is_gradient_clip_op,
is_optimize_op,
use_standalone_executor,
......@@ -372,8 +373,9 @@ class ClipGradByGloblNormPass(PassBase):
OP_ROLE_KEY: OpRole.Optimize,
},
)
# TODO better regular the usage of op namescope
allreduce_op._set_attr(
'op_namescope', "/gradient_clip_pass"
'op_namescope', str('/') + SyncMode.GlobalNormSync
)
self.clip_helper._init_dist_attr(allreduce_op)
......@@ -394,15 +396,14 @@ class ClipGradByGloblNormPass(PassBase):
prior_op = block.ops[j]
break
j -= 1
print("here: ", block.ops[j])
assert (
prior_op is not None
), "Unexception: ClipByGlobalNorm could not find priory depend op"
), "Unexpected: ClipByGlobalNorm could not find priory depend op"
prior_var = block.vars[prior_op.output_arg_names[0]]
assert (
prior_var is not None
), "Unexception: ClipByGlobalNorm could not find priory depend var"
insert_dependencies_for_two_vars(
), "Unexpected: ClipByGlobalNorm could not find priory depend var"
insert_dependencies_for_vars(
block,
idx,
prior_var,
......@@ -414,6 +415,7 @@ class ClipGradByGloblNormPass(PassBase):
], # hack to avoid initialize the dist attr for coalesc var
is_recompute=False,
sync=False,
op_namescope="grad_clip_fill_constant_dep",
)
for varname in removed_tmp_var:
......
......@@ -474,6 +474,7 @@ class RecomputePass(PassBase):
self._dist_context,
is_recompute=True,
sync=False,
op_namescope="recompute_segment_dep",
)
main_program._sync_with_cpp()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.auto_parallel.operators.common import (
is_amp_flag_sync_op,
is_data_parallel_reduce_op,
is_global_norm_sync_op,
)
from paddle.distributed.auto_parallel.utils import (
OpRole,
insert_dependencies_for_vars,
use_standalone_executor,
)
from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type
from .pass_base import PassBase, register_pass
def _sharding_pass_applied(pass_ctx):
for applied_pass in pass_ctx.passes:
if isinstance(applied_pass, ShardingPass):
return True
return False
# NOTE we add the "auto_parallel" prefix to the pass in order to
# indicate that this pass should obey some constrains by auto_parallel
# for example all ops and vars should has dist attr before and after pass
# should use dist op instead of custom comm op
@register_pass("auto_parallel_supplement_explicit_dependencies")
class AutoParalSupplementDepPass(PassBase):
"""
Functional Concern.
for strategies like amp & global norm, there is a collective communication to sync gradient inforation in every rank.
after partition the gradients to each rank, the order of that collective communication is different in each rank
and might cause hang problem in graph based random order executor. here supplement explicit dependencies for those cases.
TODO Performance Concern.
global collective will introduce global synchronization which forces the fast workers to wait for slow ones.
therefore we should conduct this collective when all the ranks reach a same stage.
BUT the depend API offered by executor could only ensure "conduct-not-before" but not "conduct-right-after".
Some ranks might call the colletives first than other ranks while they still some local could be performed to wait for slow peers.
IR Pass currently could not have the fully control of time the to perform these global collectives.
"""
def __init__(self):
super().__init__()
self.set_attr("dist_context", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
# TODO general this pass for all case.
if not use_standalone_executor or not _sharding_pass_applied(context):
return
self._dist_context = self.get_attr("dist_context", None)
self.flags_sync_stream = "flags_sync_stream"
main_block = main_program.global_block()
startup_block = startup_program.global_block()
# last dp grad communication
last_dp_reduce_op_idx = -1
last_dp_reduce_varname = None
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_data_parallel_reduce_op(op):
last_dp_reduce_op_idx = idx
last_dp_reduce_varname = op.output_arg_names[0]
break
assert last_dp_reduce_op_idx > 0
assert last_dp_reduce_varname is not None
# analyze deps for amp & global norm
deps_map = {}
prior_varname = last_dp_reduce_varname
for idx, op in enumerate(main_block.ops):
if is_amp_flag_sync_op(op) or is_global_norm_sync_op(op):
op_namescope = None
if is_amp_flag_sync_op(op):
op_namescope = "amp_flag_sync_dep"
op.dist_attr.execution_stream = self.flags_sync_stream
elif is_global_norm_sync_op(op):
op_namescope = "global_norm_sync_dep"
deps_map[idx] = (prior_varname, op.input("X")[0], op_namescope)
prior_varname = op.output("Out")[0]
# analyze deps for check_finite_and_unscale
# ensure it is performed after last backward computation, therefore reduce the
# straggling of the amp-flag-sync
first_check_op = True
for idx, op in enumerate(main_block.ops):
if op.type == "check_finite_and_unscale":
if first_check_op:
last_backward_op = main_block.ops[idx - 1]
prior_varname = last_backward_op.output_arg_names[0]
first_check_op = False
deps_map[idx] = (
prior_varname,
op.input("Scale")[0],
"check_finite_dep",
)
# analyze deps for optimizer
# optimizers order should be fixed to allow broadcast to overlap with optimizer
first_optimizer_op = True
for idx, op in enumerate(main_block.ops):
if op.type in _supported_optimizer_type:
if first_optimizer_op:
first_optimizer_op = False
else:
deps_map[idx] = (
prior_varname,
op.input("Param")[0],
"optimizer_order_dep",
)
prior_varname = op.output("ParamOut")[0]
# insert deps
indice = sorted(list(deps_map.keys()), reverse=True)
for idx in indice:
prior_var = main_block.var(deps_map[idx][0])
post_var = main_block.var(deps_map[idx][1])
op_namescope = deps_map[idx][2]
depend_op = insert_dependencies_for_vars(
main_block,
idx,
prior_var,
post_var,
self._dist_context,
OpRole.Optimize,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesc var
is_recompute=False,
sync=False,
op_namescope=op_namescope,
)
main_block._sync_with_cpp()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
paddle.enable_static()
def apply_pass(use_sharding=False, use_amp=False, use_recompute=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 2
sharding.enable_overlap = True
sharding.param_comm_stream_num = 2
sharding.grad_comm_stream_num = 2
sharding.param_bucket_size_numel = 512 * 512
sharding.grad_bucket_size_numel = 128 * 128
sharding.partition_algor = 'use_order'
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.custom_white_list = [
'lookup_table_v2',
'lookup_table',
'softmax',
'layer_norm',
'gelu',
]
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
amp.use_optimizer_fp16 = False
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(
self, use_sharding=False, use_amp=False, use_recompute=False
):
reset_prog()
strategy = apply_pass(use_sharding, use_amp, use_recompute)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_param_grad_fuse_overlap(self, program):
num_op = 0
num_coalesce = 0
num_reduce = 0
num_broadcast = 0
for op in program.global_block().ops:
if op.type == "nop" or op.type == "depend":
num_op += 1
elif op.type == "coalesce_tensor":
num_coalesce += 1
elif op.type == "c_reduce_sum":
num_reduce += 1
elif op.type == "c_broadcast":
num_broadcast += 1
if paddle.distributed.get_rank() == 0:
self.assertEqual(num_op, 22)
else:
self.assertEqual(num_op, 54)
self.assertEqual(num_coalesce, 5)
self.assertEqual(num_reduce, 14)
self.assertEqual(num_broadcast, 2)
def test_param_grad_fuse_overlap(self):
# dp2
dp_engine = self.get_engine()
dp_history = dp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
dp_loss = dp_history.history['loss'][0]
# sharding2
sharding_engine = self.get_engine(use_sharding=True)
sharding_history = sharding_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
sharding_loss = sharding_history.history['loss'][0]
# amp, recompute
amp_recompute_engine = self.get_engine(
use_sharding=False, use_amp=True, use_recompute=True
)
amp_recompute_history = amp_recompute_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
amp_recompute_loss = amp_recompute_history.history['loss'][0]
# sharding2, amp, recompute
all_engine = self.get_engine(
use_sharding=True, use_amp=True, use_recompute=True
)
all_history = all_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
all_loss = all_history.history['loss'][0]
self.check_param_grad_fuse_overlap(sharding_engine.main_program)
np.testing.assert_allclose(
dp_loss, sharding_loss, rtol=1e-05, atol=1e-08
)
np.testing.assert_allclose(
amp_recompute_loss, all_loss, rtol=1e-05, atol=1e-08
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
os.environ["FLAGS_CONVERT_GRAPH_TO_PROGRAM"] = str(1)
os.environ["FLAGS_add_dependency_for_communication_op"] = 'false'
class TestShardingWithNewEXE(unittest.TestCase):
def test_stage2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "sharding_newexe.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -52,9 +52,13 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.degree, 8)
self.assertAlmostEqual(sharding.overlap_grad_comm, False)
self.assertAlmostEqual(sharding.bucket_size_numel, -1)
self.assertAlmostEqual(sharding.enable_overlap, False)
self.assertAlmostEqual(sharding.param_comm_stream_num, 1)
self.assertAlmostEqual(sharding.grad_comm_stream_num, 1)
self.assertAlmostEqual(sharding.partition_algor, "greedy_even")
self.assertAlmostEqual(sharding.param_bucket_size_numel, 1)
self.assertAlmostEqual(sharding.grad_bucket_size_numel, 1)
self.assertAlmostEqual(sharding.enable_hierarchical_comm, False)
self.assertEqual(sharding.enable_tuning, False)
self.assertEqual(sharding.tuning_range, [])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册