Unverified commit e3faf345 authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Sharding Pass (#38502)

* auto parallel sharding base

* chmod

* add unittest

* set unittest cmake dist label

* revise code according to review

* chmod
Parent 9456170f
@@ -45,6 +45,7 @@ message ShardingConfig {
   optional bool optimize_cast = 12 [ default = false ];
   // Optimizer sharding. Temporary plans and may be deprecated
   optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
+  optional int32 stage = 14 [ default = 1 ];
 }
 
 message HybridConfig {
......
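For reference, the new `stage` field is driven through `sharding_configs` on fleet's `DistributedStrategy`. A minimal sketch of the user-facing setup, mirroring the unit test added in this commit (degree and stage values are just examples):

import paddle.distributed.fleet as fleet

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
    "sharding_degree": 2,  # number of ranks the shard is spread across
    "stage": 3,            # roughly: 1 = optimizer states, 2 = + gradients, 3 = + parameters
}
fleet.init(is_collective=True, strategy=dist_strategy)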
@@ -15,6 +15,7 @@
 from ..dist_attribute import OperatorDistributedAttribute
 
 _g_distributed_operator_impl_registries = {}
+BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale'}
 
 
 class DistributedOperatorImplContainer:
@@ -116,6 +117,14 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
     return best_compatible_impl, idx
 
 
+def is_parameter_related(varname, block):
+    if ".cast_fp" in varname:
+        varname = varname[:varname.index(".cast_fp")]
+    assert block.has_var(varname)
+    var = block.var(varname)
+    return var.is_parameter
+
+
 def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
     var_shape = block.var(src_var.name).shape
     var_topoloy = src_var_dist_attr.process_mesh.topology
......
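A quick illustration of how `is_parameter_related` treats mixed-precision names; the block and variable names below are hypothetical and assume `linear_0.w_0` is a parameter of the block:

block = main_program.global_block()
# the ".cast_fp16"/".cast_fp32" suffix is stripped before the lookup
is_parameter_related("linear_0.w_0", block)            # True
is_parameter_related("linear_0.w_0.cast_fp16", block)  # True, resolves to "linear_0.w_0"
is_parameter_related("dropout_0.tmp_0", block)         # False for a non-parameter var in the block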
@@ -15,7 +15,7 @@
 from .common import DistributedOperatorImplContainer
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator_impl_container
-from .common import register_distributed_operator_impl
+from .common import register_distributed_operator_impl, is_parameter_related
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
@@ -183,8 +183,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         need_gradient_allreduce = False
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
-                if "@GRAD" not in varname and not main_block.var(
-                        varname).is_parameter:
+                if "@GRAD" not in varname and not is_parameter_related(
+                        varname, main_block):
 
                     # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
                     process_mesh = dist_attr.process_mesh
@@ -210,8 +210,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         allreduce_vars = []
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
-                if "@GRAD" not in varname and main_block.var(
-                        varname).is_parameter:
+                if "@GRAD" not in varname and is_parameter_related(
+                        varname, main_block):
                     assert len(
                         backward_op.desc.input(input_name)
                     ) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
......
@@ -16,7 +16,7 @@ from .common import infer_shape
 from .common import DistributedOperatorImplContainer
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator_impl_container
-from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
+from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
@@ -26,7 +26,7 @@ from ..utils import compute_compatible_and_update_dim_mapping
 from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import in_dygraph_mode
-from paddle.fluid.framework import Program, Parameter, Variable, program_guard
+from paddle.fluid.framework import Program, Parameter, Variable
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..process_group import new_process_group
@@ -283,34 +283,35 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                                              allreduce_op_dist_attr)
 
         # param initialization sync
-        assert Weight_var.name not in dist_op_context.already_init_sync_vars
-        dist_op_context.already_init_sync_vars.add(Weight_var.name)
-        param = startup_block.var(Weight_var.name)
-        param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
-        process_mesh = param_dist_attr.process_mesh
-        dim_mapping = param_dist_attr.dims_mapping
-
-        # NOTE all not splited axis should be presented in mesh
-        for axis, size in enumerate(process_mesh.topology):
-            if size <= 1 or axis in dim_mapping:
-                pass
-            else:
-                group_ranks = _get_comm_group(process_mesh.processes,
-                                              process_mesh.topology, axis,
-                                              rank_id)
-                sync_group = new_process_group(group_ranks)
-
-                startup_block.append_op(
-                    type='c_broadcast',
-                    inputs={'X': param},
-                    outputs={'Out': param},
-                    attrs={
-                        'ring_id': sync_group.id,
-                        'root': 0,
-                        'use_calc_stream': True,
-                        OP_ROLE_KEY: OpRole.Forward
-                    })
-        startup_block._sync_with_cpp()
+        if Weight_var.is_parameter:
+            assert Weight_var.name not in dist_op_context.already_init_sync_vars
+            dist_op_context.already_init_sync_vars.add(Weight_var.name)
+            param = startup_block.var(Weight_var.name)
+            param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
+            process_mesh = param_dist_attr.process_mesh
+            dim_mapping = param_dist_attr.dims_mapping
+
+            # NOTE all not splited axis should be presented in mesh
+            for axis, size in enumerate(process_mesh.topology):
+                if size <= 1 or axis in dim_mapping:
+                    pass
+                else:
+                    group_ranks = _get_comm_group(process_mesh.processes,
+                                                  process_mesh.topology, axis,
                                                   rank_id)
+                    sync_group = new_process_group(group_ranks)
+
+                    startup_block.append_op(
+                        type='c_broadcast',
+                        inputs={'X': param},
+                        outputs={'Out': param},
+                        attrs={
+                            'ring_id': sync_group.id,
+                            'root': 0,
+                            'use_calc_stream': True,
+                            OP_ROLE_KEY: OpRole.Forward
+                        })
+            startup_block._sync_with_cpp()
 
     @staticmethod
     def backward(ctx, *args, **kwargs):
......
@@ -18,7 +18,7 @@ from .common import DistributedOperatorImplContainer
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator_impl_container
 from .common import register_distributed_operator_impl
-from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program
+from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
@@ -184,7 +184,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     Out_grad = main_block.var(kwargs['Out@GRAD'][0])
     Y_grad = main_block.var(kwargs['Y@GRAD'][0])
 
-    assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
+    assert not is_parameter_related(
+        X_var.name, main_block
+    ), "left operand(X) [{}] of dist matmul should not be parameter".format(
         X_var.name)
 
     Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
@@ -200,7 +202,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
             Y_var_partitioned = True
             break
 
-    if Y_var.is_parameter and Y_var_partitioned:
+    if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:
         if Y_var_dim_mapping[0] >= 0:
             # row parallel: c_identity + matmul
@@ -322,7 +324,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         dp_degree = len(group_ranks)
         dp_group = new_process_group(group_ranks)
 
-    if need_gradient_allreduce and Y_var.is_parameter:
+    if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
         Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
         allreduce_op = main_block.append_op(
             type='c_allreduce_sum',
@@ -444,6 +446,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
             y_dims_mapping), "now just support x dims > y dims"
         if len(y_dims_mapping) != 2:
             return False
+
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
......
@@ -27,6 +27,7 @@ from paddle.distributed.utils import get_logger
 from paddle.distributed.fleet import cloud_utils
 import paddle.fluid.core as core
 from paddle.fluid import program_guard
+from paddle.distributed.passes import new_pass, PassContext
 from .dist_context import DistributedContext
 from .dist_context import get_default_distributed_context
 from .dist_context import set_default_distributed_context
@@ -139,23 +140,9 @@ class AutoParallelizer:
 
     def _apply_optimize(self, main_program, startup_program, params_grads):
 
-        if self._dist_strategy.sharding:
-            auto_parallel_sharding_pass = new_pass(
-                "auto_parallel_sharding_pass", self._dist_strategy)
-            params_grads = auto_parallel_sharding_pass.apply(
-                main_program, startup_program, params_grads, self._pass_context)
-
-        if self._dist_strategy.gradient_merge:
-            auto_parallel_gradient_merge_pass = new_pass(
-                "auto_parallel_gradient_merge_pass",
-                self._dist_strategy.gradient_merge_configs)
-            auto_parallel_gradient_merge_pass.apply(
-                main_program, startup_program, params_grads, self._pass_context)
-        else:
-            with program_guard(main_program, startup_program):
-                optimizer = copy.deepcopy(self._optimizer)
-                optimize_ops = optimizer.apply_gradients(params_grads)
+        with program_guard(main_program, startup_program):
+            optimize_ops = copy.deepcopy(self._optimizer).apply_gradients(
+                params_grads)
 
         # update completion
         complete_update_annotation(
@@ -163,6 +150,19 @@ class AutoParallelizer:
 
         return optimize_ops
 
+    def _apply_post_optimization_passed(self, main_program, startup_program,
+                                        rank, params_grads):
+        if self._dist_strategy.sharding:
+            config = copy.deepcopy(self._dist_strategy.sharding_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["global_rank"] = rank
+            auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
+                                                   config)
+            auto_parallel_sharding_pass.apply(
+                [main_program], [startup_program], self._pass_context)
+
     def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
         completed_main_program = None
         serial_main_program = self._main_program.clone()
@@ -203,7 +203,8 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
 
         reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
-
+        self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
+                                             rank, dist_params_grads)
         g_process_group_map = None
         if not relaunch_phase:
             g_process_group_map = copy.deepcopy(_g_process_group_map)
......
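Taken together, sharding is no longer wired into `_apply_optimize`; it is applied afterwards through the generic pass framework. A condensed sketch of what `_apply_post_optimization_passed` does, assuming `dist_strategy`, `dist_context`, `params_grads`, `rank`, and the two distributed programs are already in scope:

import copy
from paddle.distributed.passes import new_pass, PassContext

config = copy.deepcopy(dist_strategy.sharding_configs)
config["dist_context"] = dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank

auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", config)
auto_parallel_sharding_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())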
@@ -24,7 +24,8 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op
 from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
 from .dist_attribute import OperatorDistributedAttribute
 from .process_group import new_process_group
-from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
+from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
+from .operators.common import BACKWARD_ONLY_DIST_OPS
 
 __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
@@ -102,22 +103,17 @@ class Partitioner(object):
         partitioned_startup_prog = fluid.Program()
         ref_block = serial_main_program.global_block()
         target_block = partitioned_startup_prog.global_block()
-        param2shape = {}
+        var2shape = {}
         temp_varname_map = {}
 
         # tensors
         for var in serial_startup_program.list_vars():
-            if isinstance(var, Parameter):
-                # TODO if var not belong to this rank, should be filtered
-                serial_main_var = ref_block.var(var.name)
-                dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
-                    serial_main_var)
-                target_shape = _get_dist_shape(serial_main_var, dist_attr)
-                new_name = var.name + self._dist_varname_suffix
-                temp_varname_map[var.name] = new_name
-                _partition_parameter(self._dist_context, serial_main_var,
-                                     target_block, new_name, target_shape)
-                param2shape[new_name] = target_shape
+            assert var.persistable
+            new_name = var.name + self._dist_varname_suffix
+            temp_varname_map[var.name] = new_name
+            target_shape = _partition_var(self._dist_context, ref_block,
+                                          target_block, var.name, new_name)
+            var2shape[new_name] = target_shape
 
         # ops
         for op in serial_startup_program.global_block().ops:
@@ -128,14 +124,14 @@ class Partitioner(object):
            ) == 1, "initializer should output only ONE variable, but got [{}]".format(
                 str(op.desc))
             assert temp_varname_map[output_vars[
-                0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
+                0]] in var2shape, "try to initialize [{}] which is not a persistable var".format(
                     output_vars[0])
             new_op_desc = target_block.desc.append_op()
             new_op_desc.copy_from(op.desc)
             new_op_desc._rename_output(output_vars[0],
                                        temp_varname_map[output_vars[0]])
             new_op_desc._set_attr("shape",
-                                  param2shape[temp_varname_map[output_vars[0]]])
+                                  var2shape[temp_varname_map[output_vars[0]]])
             target_block._sync_with_cpp()
 
             # set distribute atrribute
@@ -211,7 +207,6 @@ class Partitioner(object):
                                             **koutputs)
 
             elif is_backward_op(op):
-                print(str(op))
                 kinputs, koutputs = dist_op_context.prepare_context(op)
                 dist_op_backward_impl = _get_dist_op_backward_implement(
                     op, self._dist_context, forward_op_id2forward_op)
@@ -351,6 +346,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
             name=dst_varname,
            persistable=True,
             stop_gradient=True)
+        target_shape = None
     else:
         dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
         target_shape = _get_dist_shape(src_var, dist_attr)
@@ -361,6 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
         else:
             _partition_intermediate_var(dist_context, src_var, dst_block,
                                         dst_varname, target_shape)
+    return target_shape
 
 
 def _get_dist_op_backward_implement(backward_op, dist_context,
@@ -371,25 +368,32 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
         forward_op = forward_op_id2forward_op[forward_op_id]
         forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
             forward_op)
-        dist_ops = get_distributed_operator_impl_container(forward_op.type)
+        dist_op = get_distributed_operator_impl_container(forward_op.type)
 
         # TODO backward should have its own impl_idx
-        if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
+        if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
                 forward_op_dist_attr.impl_idx)._backward_implemented:
-            return dist_ops.get_impl(forward_op_dist_attr.impl_idx)
+            return dist_op.get_impl(forward_op_dist_attr.impl_idx)
 
-    dist_ops = get_distributed_operator_impl_container("default")
-    return dist_ops.get_impl(0)
+    # NOTE trick for dist ops that only have backward implement
+    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
+        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
+        assert op_dist_attr.impl_idx >= 0
+        return get_distributed_operator_impl_container(
+            backward_op.type).get_impl(op_dist_attr.impl_idx)
+
+    dist_op = get_distributed_operator_impl_container("default")
+    return dist_op.get_impl(0)
 
 
 def _get_dist_op_forward_implement(forward_op, dist_context):
     dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
-    dist_ops = get_distributed_operator_impl_container(forward_op.type)
-    if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl(
+    dist_op = get_distributed_operator_impl_container(forward_op.type)
+    if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
             dist_attr.impl_idx)._forward_implemented:
-        return dist_ops.get_impl(dist_attr.impl_idx)
+        return dist_op.get_impl(dist_attr.impl_idx)
     else:
-        dist_ops = get_distributed_operator_impl_container("default")
-        return dist_ops.get_impl(0)
+        dist_op = get_distributed_operator_impl_container("default")
+        return dist_op.get_impl(0)
@@ -25,6 +25,7 @@ import paddle.fluid.core as core
 from paddle.framework.io import _to_LodTensor
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.fluid.io import is_parameter, is_belong_to_optimizer
+from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
 
 
 def is_valid_list_index(list, index):
@@ -993,18 +994,23 @@ def set_grad_var_shape(program, dist_context):
     block = program.global_block()
     vars = block.vars
     for op in block.ops:
-        if op.type in [
-                "sum", "check_finite_and_unscale", "update_loss_scaling"
-        ]:
+
+        if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
+            break
+        if op.type in ["sum"]:
             continue
+
         if int(op.attr('op_role')) == int(OpRole.Backward):
             op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
             assert op_dist_attr is not None
             for var_name in op.output_arg_names:
                 assert "@GRAD" in var_name
                 forward_var_name = var_name[:var_name.find("@GRAD")]
-                if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
+                if op.type in [
+                        "c_allreduce_sum", "c_identity", "scale", "cast"
+                ]:
                     forward_var_name = op.input_arg_names[0]
                 elif op.type == "matmul_v2_grad":
                     forward_var_name = None
@@ -1038,6 +1044,7 @@ def set_grad_var_shape(program, dist_context):
                 forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
                     forward_var_name)
                 assert forward_input_dist_attr is not None, f"{forward_var_name}"
+
                 forward_var = vars[forward_var_name]
                 forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
@@ -1069,6 +1076,53 @@ def is_backward_op(op):
            int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
 
 
+def is_recompute_op(op):
+    return OP_ROLE_KEY in op.attr_names and \
+            int(op.all_attrs()[OP_ROLE_KEY]) == 9
+
+
+def is_loss_op(op):
+    return OP_ROLE_KEY in op.attr_names and \
+        int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
+
+
+def get_loss_op(block):
+    loss_ops = []
+    for op in block.ops:
+        if is_loss_op(op):
+            assert len(op.desc.output_arg_names(
+            )) == 1, "loss op should only output loss var"
+            loss_ops.append(op)
+
+    assert len(loss_ops) == 1, "num of loss op is not equal to one"
+    return loss_ops[0]
+
+
+def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
+    tensor_dist_attr = TensorDistributedAttribute()
+    tensor_dist_attr.dims_mapping = dims_mapping
+    # TODO get global mesh group
+    tensor_dist_attr.process_mesh = process_mesh
+    dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
+    return tensor_dist_attr
+
+
+def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(new_op, process_mesh,
+                                                           ref_mapping, ctx):
+    assert process_mesh is not None
+    assert ref_mapping is not None
+
+    new_op_dist_attr = OperatorDistributedAttribute()
+
+    for input_varname in new_op.desc.input_arg_names():
+        new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping)
+    for output_varname in new_op.desc.output_arg_names():
+        new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping)
+
+    new_op_dist_attr.process_mesh = process_mesh
+    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
+
+
 def update_op_dims_mapping_by_default_dist_impl(dist_op):
     changed = False
     op_dist_attr = dist_op.dist_attr
......
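The two new helpers are meant to be used in tandem whenever a pass creates a fresh variable plus a communication op around it. A rough sketch modeled on how the sharding pass below uses them, assuming `block`, `dist_context`, a reference variable `ref_var`, and a process group `group` already exist:

from paddle.fluid import unique_name

ref_attr = dist_context.get_tensor_dist_attr_for_program(ref_var)

new_var = block.create_var(
    name=unique_name.generate(ref_var.name + "@BroadCast"),
    shape=ref_var.shape,
    dtype=ref_var.dtype,
    persistable=False)
set_var_dist_attr(dist_context, new_var, ref_attr.dims_mapping,
                  ref_attr.process_mesh)

new_op = block.append_op(
    type='c_broadcast',
    inputs={'X': new_var},
    outputs={'Out': new_var},
    attrs={'ring_id': group.id, 'root': 0, 'use_calc_stream': True})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
    new_op, ref_attr.process_mesh, ref_attr.dims_mapping, dist_context)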
@@ -14,6 +14,7 @@
 from .pass_base import new_pass, PassManager, PassContext
 from .fuse_all_reduce import *
+from .auto_parallel_sharding import *
 from .cpp_pass import *
 
 __all__ = [
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
from collections import OrderedDict
import numpy as np
import paddle
from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
from paddle.distributed.auto_parallel.process_group import get_world_process_groups, new_process_group
from paddle.distributed.auto_parallel.operators.common import is_parameter_related
from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice']
# update here to support new optimizers
_supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
"lars_momentum", "merged_momentum", "lamb", "sgd"
]
# NOTE we add the "auto_parallel" prefix to the pass in order to
# indicate that this pass should obey some constraints of auto_parallel:
# for example, all ops and vars should have dist attr before and after the pass,
# and dist ops should be used instead of custom comm ops.
@register_pass("auto_parallel_sharding")
class ShardingPass(PassBase):
def __init__(self):
super(ShardingPass, self).__init__()
self.set_attr("dist_context", None)
self.set_attr("stage", None)
self.set_attr("sharding_degree", None)
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
self.sharding_infos = []
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.outer_dp_group = None
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if self.get_attr("stage") not in [1, 2, 3]:
return False
if (not isinstance(self.get_attr("sharding_degree"),
int)) or self.get_attr("sharding_degree") <= 1:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
if (not isinstance(self.get_attr("global_rank"),
int)) or self.get_attr("global_rank") < 0:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
self._dist_context = self.get_attr("dist_context")
self.sharding_world_size = int(self.get_attr("sharding_degree"))
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
params_grads = self.get_attr("params_grads")
main_block, startup_block = main_program.global_block(
), startup_program.global_block()
self._build_sharding_groups(main_block, params_grads)
self._shard_optimizer(main_block, startup_block, params_grads, context)
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)
def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
def _collective_data_parallel_groups(self, main_block):
for op in main_block.ops:
if op.type in _skip_ops:
continue
group = _inference_data_parallel_group_for_operator(
self.global_rank, op, self._dist_context)
if group is not None:
self.dp_groups.add(group)
        # TODO(JZ-LIANG) allow more than one dp group in the network, support more general distribution
        # generated by auto search
        if len(self.dp_groups) != 1:
            raise NotImplementedError(
                "So far only and exactly one data parallel group in the network is supported, but got [{}] different data parallel groups".
                format(len(self.dp_groups)))
def _build_sharding_infos(self, params_grads):
for dp_group in self.dp_groups:
            assert dp_group.nranks >= self.sharding_world_size, "sharding world size [{}] should not be larger than dp world size [{}]".format(
                self.sharding_world_size, dp_group.nranks)
            assert dp_group.nranks % self.sharding_world_size == 0, "dp world size [{}] should be divisible by sharding world size [{}]".format(
                dp_group.nranks, self.sharding_world_size)
            assert self.global_rank in dp_group.ranks, "current rank [{}] does NOT belong to the data parallel group [{}]".format(
                self.global_rank, dp_group.ranks)
            assert len(
                params_grads
            ) >= self.sharding_world_size, "number of parameters [{}] is not enough to be sharded among [{}] ranks".format(
                len(params_grads), self.sharding_world_size)
# sharding hybrid data parallel: partial sharding param within
if dp_group.nranks > self.sharding_world_size:
self.partial_sharding = True
                assert len(
                    self.dp_groups
                ) == 1, "hybrid sharding and data parallelism are supported only when there is exactly one data parallel group in the network"
outer_dp_group, sharding_group = _get_dp_and_sharding_groups(
dp_group.ranks, self.sharding_world_size, self.global_rank)
sharding_group = new_process_group(sharding_group)
self.outer_dp_group = new_process_group(outer_dp_group)
else:
sharding_group = dp_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(set(
params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank,
params_in_group)
self.sharding_infos.append(sharding_info)
for param in params_in_group:
self.varname_to_sharding_info[param.name] = sharding_info
def _shard_optimizer(self, main_block, startup_block, params_grads,
pass_context):
"""
        shard all optimizer related ops and vars, including:
gradient clip ops & vars
weight decay ops & vars
optimizer ops and states
"""
self._shard_amp_related_op_and_vars(main_block, pass_context)
self._shard_weight_decay(main_block)
self._shard_gradient_clip(main_block)
self._shard_optimizer_ops_and_states(main_block, startup_block)
self._insert_optimizer_broadcasts(main_block, startup_block)
def _shard_amp_related_op_and_vars(self, main_block, pass_context):
if self.stage < 2:
return
for idx, op in reversed(list(enumerate(main_block.ops))):
# shard amp related param_grad cast
if _is_param_grad_fp32_cast_op(main_block, op):
output_name = op.output_arg_names[0]
param_name = output_name[:output_name.find("@")]
if not self._is_parameter_in_local_shard(param_name):
main_block._remove_op(idx, sync=False)
main_block._remove_var(output_name, sync=False)
# shard check nan inf
elif op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
for input_name in op.desc.input('X'):
param_name = input_name[:input_name.find("@")]
if self._is_parameter_in_local_shard(param_name):
reversed_x.append(input_name)
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)
main_block._sync_with_cpp()
def _shard_gradient_clip(self, main_block):
if self.stage < 2:
return
# TODO (JZ-LIANG) support calculate global norm with tensor parallelism
is_clip_grad_by_global_norm = False
for idx, op in list(enumerate(main_block.ops)):
if not _is_gradient_clip_op(op):
continue
if op.type == 'sum':
is_clip_grad_by_global_norm = True
break
if not is_clip_grad_by_global_norm:
return
removed_op_idx = set()
removed_tmp_var = set()
for idx, op in list(enumerate(main_block.ops)):
if not _is_gradient_clip_op(op):
continue
if op.type == 'sum':
reserved_vars = []
for input_name in op.input_arg_names:
if input_name not in removed_tmp_var:
reserved_vars.append(input_name)
op.desc.set_input("X", reserved_vars)
sum_op_output = op.desc.output_arg_names()[0]
for i, sharding_info in enumerate(self.sharding_infos):
new_op = main_block._insert_op(
idx + i,
type='c_allreduce_sum',
inputs={'X': [sum_op_output]},
outputs={'Out': [sum_op_output]},
attrs={
'ring_id': sharding_info.group.id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
main_block.var(sum_op_output))
assert dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
self._dist_context)
break
for input_name in op.input_arg_names:
param_name = input_name[:input_name.find("@GRAD")]
if not self._is_parameter_in_local_shard(param_name):
removed_op_idx.add(idx)
for output_name in op.output_arg_names:
removed_tmp_var.add(output_name)
for idx, op in reversed(list(enumerate(main_block.ops))):
if not _is_gradient_clip_op(op):
continue
if idx in removed_op_idx:
main_block._remove_op(idx, sync=False)
for varname in removed_tmp_var:
main_block._remove_var(varname, sync=False)
main_block._sync_with_cpp()
def _shard_weight_decay(self, main_block):
if self.stage < 2:
return
for idx, op in reversed(list(enumerate(main_block.ops))):
if not _is_weight_decay_op(op):
continue
            else:
                raise NotImplementedError(
                    "weight decay is NOT supported by sharding yet")
main_block._sync_with_cpp()
def _shard_optimizer_ops_and_states(self, main_block, startup_block):
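        # drop optimizer ops whose parameter is not owned by this rank, collect
        # their optimizer-state outputs, and remove those states (and their
        # initialization ops) from both the main and the startup program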
should_removed_optimizer_states = []
for idx, op in reversed(list(enumerate(main_block.ops))):
if not is_optimizer_op(op):
break
if op.type in _supported_optimizer_type:
assert "Param" in op.input_names
assert len(op.input("Param")) == 1
param_name = op.input("Param")[0]
if not self._is_parameter_in_local_shard(param_name):
should_removed_optimizer_states.extend([
varname for varname in op.output_arg_names
if varname != param_name
])
main_block._remove_op(idx, sync=False)
for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[
0] in should_removed_optimizer_states:
startup_block._remove_op(idx, sync=False)
for varname in should_removed_optimizer_states:
if main_block.has_var(varname):
main_block._remove_var(varname, sync=False)
if startup_block.has_var(varname):
startup_block._remove_var(varname, sync=False)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
def _insert_optimizer_broadcasts(self, main_block, startup_block):
if self.stage > 2:
return
for sharding_info in self.sharding_infos:
for param in sharding_info.params:
assert main_block.has_var(param.name)
assert startup_block.has_var(param.name)
new_op = main_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sharding_info.group.id,
'root': sharding_info.get_var_rank(param.name),
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
param)
assert param_dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, param_dist_attr.process_mesh,
param_dist_attr.dims_mapping, self._dist_context)
main_block._sync_with_cpp()
def _is_parameter_in_local_shard(self, param_name):
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name)
def _shard_gradient_synchronization(self, main_block):
if self.stage < 2:
return
dp_ring_ids = [group.id for group in self.dp_groups]
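        # replace each data-parallel gradient allreduce with a reduce to the rank
        # that owns the parameter; with partial (hybrid) sharding the original
        # allreduce is kept but rewired to the outer data parallel group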
for idx, op in reversed(list(enumerate(main_block.ops))):
if _is_param_grad_allreduce_op(op, main_block, dp_ring_ids):
input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name]
_insert_reduce_op(
main_block, idx, input_name, sharding_info.group.id,
sharding_info.get_var_rank(base_name), self._dist_context)
if not self.partial_sharding:
main_block._remove_op(idx + 1, sync=False)
else:
op._set_attr("ring_id", self.outer_dp_group.id)
main_block._sync_with_cpp()
def _shard_parameter(self, main_block, startup_block):
if self.stage < 3:
return
dp_ring_ids = [group.id for group in self.dp_groups]
for sharding_info in self.sharding_infos:
need_broadcast_vars, param_usage = sharding_info.get_broadcast_vars_and_param_usage(
main_block)
            not_used_param_names = []
for param_name in param_usage:
if param_usage[param_name] == 0 and sharding_info.get_var_rank(
param_name) != sharding_info.local_rank:
                    not_used_param_names.append(param_name)
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
if op.type == "cast":
continue
if input_name not in need_broadcast_vars:
continue
root_rank = sharding_info.get_var_rank(input_name)
if root_rank == sharding_info.local_rank:
broadcast_varname = input_name
else:
broadcast_varname = unique_name.generate(input_name +
"@BroadCast")
input_var = main_block.var(input_name)
new_var = main_block.create_var(
name=broadcast_varname,
shape=input_var.shape,
dtype=input_var.dtype,
persistable=False)
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
input_var)
out_var_dist_attr = set_var_dist_attr(
self._dist_context, new_var,
ref_dist_attr.dims_mapping,
ref_dist_attr.process_mesh)
op._rename_input(input_name, broadcast_varname)
_insert_init_and_broadcast_op(
main_block, idx, broadcast_varname,
sharding_info.local_rank, root_rank,
sharding_info.group.id,
op.attr('op_role'), self._dist_context)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != "cast":
continue
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
            if input_name in not_used_param_names:
main_block._remove_op(idx, sync=False)
main_block._remove_var(output_name, sync=False)
for idx, op in reversed(list(enumerate(startup_block.ops))):
assert len(op.output_arg_names) == 1
output_name = op.output_arg_names[0]
if op.type == "c_broadcast" and op.attr(
"ring_id") in dp_ring_ids:
if self.outer_dp_group and sharding_info.get_var_rank(
output_name) == sharding_info.local_rank:
op._set_attr("ring_id", self.outer_dp_group.id)
else:
startup_block._remove_op(idx, sync=False)
continue
if op.type != "c_broadcast" and output_name in not_used_param_nane:
startup_block._remove_op(idx, sync=False)
        for varname in not_used_param_names:
main_block._remove_var(varname, sync=False)
startup_block._remove_var(varname, sync=False)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank,
root_rank, ring_id, op_role, dist_context):
"""
    insert a c_broadcast for `varname`; on non-root ranks an `empty` op is inserted first to initialize the receiving buffer
"""
broadcast_var = block.var(varname)
broadcast_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
broadcast_var)
new_op = block._insert_op_without_sync(
insert_idx,
type='c_broadcast',
inputs={'X': varname},
outputs={'Out': varname},
attrs={
'ring_id': ring_id,
'root': root_rank,
'use_calc_stream': True,
OP_ROLE_KEY: op_role
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, broadcast_var_dist_attr.process_mesh,
broadcast_var_dist_attr.dims_mapping, dist_context)
if local_rank != root_rank:
new_op = block._insert_op_without_sync(
insert_idx,
type="empty",
outputs={"Out": broadcast_var.name},
attrs={
"shape": broadcast_var.shape,
"dtype": broadcast_var.dtype,
OP_ROLE_KEY: op_role
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, broadcast_var_dist_attr.process_mesh,
broadcast_var_dist_attr.dims_mapping, dist_context)
return
def _insert_reduce_op(block,
insert_idx,
reduce_var,
ring_id,
root_id,
dist_context,
op_role=OpRole.Backward,
use_calc_stream=True):
    assert root_id >= 0, "root id should be a non-negative int, but got root id {}".format(
        root_id)
new_op = block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': [reduce_var]},
outputs={'Out': [reduce_var]},
attrs={
'ring_id': ring_id,
'root_id': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
dist_attr = dist_context.get_tensor_dist_attr_for_program(
block.var(reduce_var))
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, dist_attr.process_mesh, dist_attr.dims_mapping, dist_context)
def _get_dp_and_sharding_groups(origin_group, sharding_group_size, rank):
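    # e.g. origin_group = [0..7], sharding_group_size = 4, rank = 1
    #   shape = [2, 4]  ->  dp_group = [1, 5], sharding_group = [0, 1, 2, 3]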
dp_axis = 0
sharding_axis = 1
shape = [len(origin_group) // sharding_group_size, sharding_group_size]
dp_group = _get_comm_group(origin_group, shape, dp_axis, rank)
sharding_group = _get_comm_group(origin_group, shape, sharding_axis, rank)
return dp_group, sharding_group
def _is_gradient_clip_op(op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def _is_weight_decay_op(op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/regularization")
def _is_param_grad_fp32_cast_op(block, op):
if not is_backward_op(op):
return False
if not _is_desired_cast_op(block, op, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32):
return False
output_name = op.desc.output_arg_names()[0]
base_name = output_name[:output_name.find("@")]
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
def _is_param_fp16_cast_op(block, op, params):
if is_optimizer_op(op):
return False
if not _is_desired_cast_op(block, op):
return False
input_name = op.desc.input_arg_names()[0]
if input_name not in params:
return False
return True
def _is_desired_cast_op(block,
op,
src_var_type=core.VarDesc.VarType.FP32,
dst_var_type=core.VarDesc.VarType.FP16):
if op.type != "cast":
return False
assert (len(op.desc.input_arg_names()) == 1)
assert (len(op.desc.output_arg_names()) == 1)
input_var = block.var(op.desc.input_arg_names()[0])
output_var = block.var(op.desc.output_arg_names()[0])
if input_var.dtype != src_var_type or \
output_var.dtype != dst_var_type:
return False
return True
def _get_base_name_from_grad_name(grad_name):
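    # e.g. "linear_0.w_0.cast_fp16@GRAD" -> "linear_0.w_0"
    #      "linear_0.w_0@GRAD"           -> "linear_0.w_0"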
base_name = None
if ".cast_fp16@GRAD" in grad_name:
base_name = grad_name[:grad_name.find(".cast_fp16@GRAD")]
elif "@GRAD" in grad_name:
base_name = grad_name[:grad_name.find("@GRAD")]
return base_name
def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
if not is_backward_op(op):
return False
if op.type != "c_allreduce_sum":
return False
if op.attr('ring_id') not in dp_ring_ids:
return False
output_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(output_name)
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dp_group = None
for input_name in op.input_arg_names:
if not is_parameter_related(input_name, op.block):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
mesh_shape = process_mesh.topology
# TODO(JZ-LIANG) replace with specific batch size dimension
batch_size_axis = input_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_group = new_process_group(group_ranks)
break
return dp_group
def shard_parameters(params, group_size):
# TODO(JZ-LIANG) support multiple partition methods
    # method1: greedy even but unordered
    # method2: roughly even with order preserved
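    # e.g. group_size = 2 and param numels [4, 3, 2, 1] gives
    #   rank 0 -> params with numel 4 and 1, rank 1 -> params with numel 3 and 2
    # (each param goes to the rank with the smallest accumulated numel so far)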
mapping = {}
for rank_ in range(group_size):
mapping[rank_] = []
sizes = [0] * group_size
for param in params:
rank = sizes.index(min(sizes))
mapping[rank].append(param)
numel = reduce(lambda x, y: x * y, param.shape)
        assert numel > 0, "numel of param [{}] should be larger than 0, but got [{}]".format(
            param.name, numel)
sizes[rank] += numel
return mapping
class ShardingInfo(object):
def __init__(self, group, rank, params):
self.group = group
self.params = params
self.param_names = [p.name for p in self.params]
self.group_size = group.nranks
self.global_rank = rank
self.local_rank = group.ranks.index(self.global_rank)
# rank in below mapping are local rank in this sharding group
self.rank_to_params = shard_parameters(self.params, self.group_size)
# include fp32 and fp16 param
self.param_to_rank = dict()
self._map_param_to_rank()
def _map_param_to_rank(self):
"""
        map each parameter to the rank that holds it.
"""
for rank, params in self.rank_to_params.items():
for param in params:
self.param_to_rank[param.name] = rank
def get_var_rank(self, varname):
if varname in self.param_to_rank:
return self.param_to_rank[varname]
return -1
def is_in_local_shard(self, param_name):
return self.get_var_rank(param_name) == self.local_rank
def get_broadcast_vars_and_param_usage(self, block):
broadcast_vars = set([])
fp16_params = set([])
fp16_to_fp32 = {}
param_usage = {x: 0 for x in self.param_names}
for op in block.ops:
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
if input_name in self.param_names:
param_usage[input_name] += 1
for op in block.ops:
if not _is_param_fp16_cast_op(block, op, self.param_names):
continue
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
broadcast_vars.add(output_name)
fp16_params.add(output_name)
fp16_to_fp32[output_name] = input_name
param_usage[input_name] -= 1
self.param_to_rank[output_name] = self.param_to_rank[input_name]
for param, usage in param_usage.items():
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars, param_usage
@@ -5,4 +5,5 @@ foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP})
   list(APPEND DIST_TEST_OPS ${TEST_OP})
   set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 90)
+  set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST")
 endforeach(TEST_OP)
@@ -63,12 +63,12 @@ class AutoPallelPassTestBase(DistPassTestBase):
     def check_main(self, gpus=None, **kwargs):
         no_pass_rets = self._distributed_launch(
-            apply_pass=False, gpus=gpus, **kwargs)
+            model=None, apply_pass=False, gpus=gpus, **kwargs)
         pass_rets = self._distributed_launch(
-            apply_pass=True, gpus=gpus, **kwargs)
+            model=None, apply_pass=True, gpus=gpus, **kwargs)
         self.check_results(no_pass_rets, pass_rets)
 
-    def _run_gpu_main(self, apply_pass, dump_file, **kwargs):
+    def _run_gpu_main(self, model, apply_pass, dump_file, **kwargs):
         gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
         place = paddle.CUDAPlace(gpu_id)
         scope = paddle.static.Scope()
@@ -82,8 +82,8 @@ class AutoPallelPassTestBase(DistPassTestBase):
             with paddle.fluid.unique_name.guard():
                 main_prog, startup_prog, inputs, outputs, reader = self.get_model(
                     place, **kwargs)
-                inputs = self._to_var_names(main_prog, inputs)
-                outputs = self._to_var_names(main_prog, outputs)
+                inputs = self._to_var_names(inputs)
+                outputs = self._to_var_names(outputs)
 
         all_fetch_values = []
         exe = paddle.static.Executor(place)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.passes import new_pass, PassManager
from auto_parallel_pass_test_base import AutoPallelPassTestBase
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
class TestShardingPass(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-5
self.atol = 1e-8
rank = paddle.distributed.get_rank()
paddle.seed(rank + 2021)
random.seed(rank + 2021)
np.random.seed(rank + 2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
}
fleet.init(is_collective=True, strategy=dist_strategy)
def apply_no_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.pipeline = False
dist_strategy.recompute = False
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_bs_8(self):
self.check_main(
gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000)
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model('dp', place, batch_size, sequence_len,
vocab_size)
if __name__ == "__main__":
unittest.main()