未验证 提交 e3faf345 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Sharding Pass (#38502)

* auto parallel sharding base

* chmod

* add unitest

* set unitest cmake dist label

* revise code according to rewiew

* chmod
上级 9456170f
...@@ -45,6 +45,7 @@ message ShardingConfig { ...@@ -45,6 +45,7 @@ message ShardingConfig {
optional bool optimize_cast = 12 [ default = false ]; optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plans and may be deprecated // Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ]; optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
optional int32 stage = 14 [ default = 1 ];
} }
message HybridConfig { message HybridConfig {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_registries = {} _g_distributed_operator_impl_registries = {}
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale'}
class DistributedOperatorImplContainer: class DistributedOperatorImplContainer:
...@@ -116,6 +117,14 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): ...@@ -116,6 +117,14 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
return best_compatible_impl, idx return best_compatible_impl, idx
def is_parameter_related(varname, block):
if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")]
assert block.has_var(varname)
var = block.var(varname)
return var.is_parameter
def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
var_shape = block.var(src_var.name).shape var_shape = block.var(src_var.name).shape
var_topoloy = src_var_dist_attr.process_mesh.topology var_topoloy = src_var_dist_attr.process_mesh.topology
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
...@@ -183,8 +183,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -183,8 +183,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
need_gradient_allreduce = False need_gradient_allreduce = False
for input_name in backward_op.desc.input_names(): for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name): for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not main_block.var( if "@GRAD" not in varname and not is_parameter_related(
varname).is_parameter: varname, main_block):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
...@@ -210,8 +210,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -210,8 +210,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
allreduce_vars = [] allreduce_vars = []
for input_name in backward_op.desc.input_names(): for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name): for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and main_block.var( if "@GRAD" not in varname and is_parameter_related(
varname).is_parameter: varname, main_block):
assert len( assert len(
backward_op.desc.input(input_name) backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format( ) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
......
...@@ -16,7 +16,7 @@ from .common import infer_shape ...@@ -16,7 +16,7 @@ from .common import infer_shape
from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program from .common import register_distributed_operator_impl, set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
...@@ -26,7 +26,7 @@ from ..utils import compute_compatible_and_update_dim_mapping ...@@ -26,7 +26,7 @@ from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group from ..process_group import new_process_group
...@@ -283,34 +283,35 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -283,34 +283,35 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
allreduce_op_dist_attr) allreduce_op_dist_attr)
# param initialization sync # param initialization sync
assert Weight_var.name not in dist_op_context.already_init_sync_vars if Weight_var.is_parameter:
dist_op_context.already_init_sync_vars.add(Weight_var.name) assert Weight_var.name not in dist_op_context.already_init_sync_vars
param = startup_block.var(Weight_var.name) dist_op_context.already_init_sync_vars.add(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) param = startup_block.var(Weight_var.name)
process_mesh = param_dist_attr.process_mesh param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
dim_mapping = param_dist_attr.dims_mapping process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping
# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology): # NOTE all not splited axis should be presented in mesh
if size <= 1 or axis in dim_mapping: for axis, size in enumerate(process_mesh.topology):
pass if size <= 1 or axis in dim_mapping:
else: pass
group_ranks = _get_comm_group(process_mesh.processes, else:
process_mesh.topology, axis, group_ranks = _get_comm_group(process_mesh.processes,
rank_id) process_mesh.topology, axis,
sync_group = new_process_group(group_ranks) rank_id)
sync_group = new_process_group(group_ranks)
startup_block.append_op(
type='c_broadcast', startup_block.append_op(
inputs={'X': param}, type='c_broadcast',
outputs={'Out': param}, inputs={'X': param},
attrs={ outputs={'Out': param},
'ring_id': sync_group.id, attrs={
'root': 0, 'ring_id': sync_group.id,
'use_calc_stream': True, 'root': 0,
OP_ROLE_KEY: OpRole.Forward 'use_calc_stream': True,
}) OP_ROLE_KEY: OpRole.Forward
startup_block._sync_with_cpp() })
startup_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
......
...@@ -18,7 +18,7 @@ from .common import DistributedOperatorImplContainer ...@@ -18,7 +18,7 @@ from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
from ..utils import is_valid_list_index from ..utils import is_valid_list_index
...@@ -184,7 +184,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -184,7 +184,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Out_grad = main_block.var(kwargs['Out@GRAD'][0]) Out_grad = main_block.var(kwargs['Out@GRAD'][0])
Y_grad = main_block.var(kwargs['Y@GRAD'][0]) Y_grad = main_block.var(kwargs['Y@GRAD'][0])
assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( assert not is_parameter_related(
X_var.name, main_block
), "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name) X_var.name)
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
...@@ -200,7 +202,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -200,7 +202,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Y_var_partitioned = True Y_var_partitioned = True
break break
if Y_var.is_parameter and Y_var_partitioned: if is_parameter_related(Y_var.name, main_block) and Y_var_partitioned:
if Y_var_dim_mapping[0] >= 0: if Y_var_dim_mapping[0] >= 0:
# row parallel: c_identity + matmul # row parallel: c_identity + matmul
...@@ -322,7 +324,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -322,7 +324,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks) dp_group = new_process_group(group_ranks)
if need_gradient_allreduce and Y_var.is_parameter: if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op( allreduce_op = main_block.append_op(
type='c_allreduce_sum', type='c_allreduce_sum',
...@@ -444,6 +446,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -444,6 +446,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
y_dims_mapping), "now just support x dims > y dims" y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2: if len(y_dims_mapping) != 2:
return False return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
......
...@@ -27,6 +27,7 @@ from paddle.distributed.utils import get_logger ...@@ -27,6 +27,7 @@ from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid import program_guard from paddle.fluid import program_guard
from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context from .dist_context import set_default_distributed_context
...@@ -139,23 +140,9 @@ class AutoParallelizer: ...@@ -139,23 +140,9 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads): def _apply_optimize(self, main_program, startup_program, params_grads):
if self._dist_strategy.sharding: with program_guard(main_program, startup_program):
auto_parallel_sharding_pass = new_pass( optimize_ops = copy.deepcopy(self._optimizer).apply_gradients(
"auto_parallel_sharding_pass", self._dist_strategy) params_grads)
params_grads = auto_parallel_sharding_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
if self._dist_strategy.gradient_merge:
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass",
self._dist_strategy.gradient_merge_configs)
auto_parallel_gradient_merge_pass.apply(
main_program, startup_program, params_grads, self._pass_context)
else:
with program_guard(main_program, startup_program):
optimizer = copy.deepcopy(self._optimizer)
optimize_ops = optimizer.apply_gradients(params_grads)
# update completion # update completion
complete_update_annotation( complete_update_annotation(
...@@ -163,6 +150,19 @@ class AutoParallelizer: ...@@ -163,6 +150,19 @@ class AutoParallelizer:
return optimize_ops return optimize_ops
def _apply_post_optimization_passed(self, main_program, startup_program,
rank, params_grads):
if self._dist_strategy.sharding:
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config)
auto_parallel_sharding_pass.apply(
[main_program], [startup_program], self._pass_context)
def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None completed_main_program = None
serial_main_program = self._main_program.clone() serial_main_program = self._main_program.clone()
...@@ -203,7 +203,8 @@ class AutoParallelizer: ...@@ -203,7 +203,8 @@ class AutoParallelizer:
make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context) make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context) reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None g_process_group_map = None
if not relaunch_phase: if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map) g_process_group_map = copy.deepcopy(_g_process_group_map)
......
...@@ -24,7 +24,8 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op ...@@ -24,7 +24,8 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group from .process_group import new_process_group
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
...@@ -102,22 +103,17 @@ class Partitioner(object): ...@@ -102,22 +103,17 @@ class Partitioner(object):
partitioned_startup_prog = fluid.Program() partitioned_startup_prog = fluid.Program()
ref_block = serial_main_program.global_block() ref_block = serial_main_program.global_block()
target_block = partitioned_startup_prog.global_block() target_block = partitioned_startup_prog.global_block()
param2shape = {} var2shape = {}
temp_varname_map = {} temp_varname_map = {}
# tensors # tensors
for var in serial_startup_program.list_vars(): for var in serial_startup_program.list_vars():
if isinstance(var, Parameter): assert var.persistable
# TODO if var not belong to this rank, should be filtered new_name = var.name + self._dist_varname_suffix
serial_main_var = ref_block.var(var.name) temp_varname_map[var.name] = new_name
dist_attr = self._dist_context.get_tensor_dist_attr_for_program( target_shape = _partition_var(self._dist_context, ref_block,
serial_main_var) target_block, var.name, new_name)
target_shape = _get_dist_shape(serial_main_var, dist_attr) var2shape[new_name] = target_shape
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
_partition_parameter(self._dist_context, serial_main_var,
target_block, new_name, target_shape)
param2shape[new_name] = target_shape
# ops # ops
for op in serial_startup_program.global_block().ops: for op in serial_startup_program.global_block().ops:
...@@ -128,14 +124,14 @@ class Partitioner(object): ...@@ -128,14 +124,14 @@ class Partitioner(object):
) == 1, "initializer should output only ONE variable, but got [{}]".format( ) == 1, "initializer should output only ONE variable, but got [{}]".format(
str(op.desc)) str(op.desc))
assert temp_varname_map[output_vars[ assert temp_varname_map[output_vars[
0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( 0]] in var2shape, "try to initialize [{}] which is not a persistable var".format(
output_vars[0]) output_vars[0])
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(output_vars[0], new_op_desc._rename_output(output_vars[0],
temp_varname_map[output_vars[0]]) temp_varname_map[output_vars[0]])
new_op_desc._set_attr("shape", new_op_desc._set_attr("shape",
param2shape[temp_varname_map[output_vars[0]]]) var2shape[temp_varname_map[output_vars[0]]])
target_block._sync_with_cpp() target_block._sync_with_cpp()
# set distribute atrribute # set distribute atrribute
...@@ -211,7 +207,6 @@ class Partitioner(object): ...@@ -211,7 +207,6 @@ class Partitioner(object):
**koutputs) **koutputs)
elif is_backward_op(op): elif is_backward_op(op):
print(str(op))
kinputs, koutputs = dist_op_context.prepare_context(op) kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement( dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op) op, self._dist_context, forward_op_id2forward_op)
...@@ -351,6 +346,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, ...@@ -351,6 +346,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
name=dst_varname, name=dst_varname,
persistable=True, persistable=True,
stop_gradient=True) stop_gradient=True)
target_shape = None
else: else:
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
target_shape = _get_dist_shape(src_var, dist_attr) target_shape = _get_dist_shape(src_var, dist_attr)
...@@ -361,6 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, ...@@ -361,6 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
else: else:
_partition_intermediate_var(dist_context, src_var, dst_block, _partition_intermediate_var(dist_context, src_var, dst_block,
dst_varname, target_shape) dst_varname, target_shape)
return target_shape
def _get_dist_op_backward_implement(backward_op, dist_context, def _get_dist_op_backward_implement(backward_op, dist_context,
...@@ -371,25 +368,32 @@ def _get_dist_op_backward_implement(backward_op, dist_context, ...@@ -371,25 +368,32 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op = forward_op_id2forward_op[forward_op_id] forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op) forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type) dist_op = get_distributed_operator_impl_container(forward_op.type)
# TODO backward should have its own impl_idx # TODO backward should have its own impl_idx
if dist_ops and forward_op_dist_attr.impl_idx >= 0 and dist_ops.get_impl( \ if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
forward_op_dist_attr.impl_idx)._backward_implemented: forward_op_dist_attr.impl_idx)._backward_implemented:
return dist_ops.get_impl(forward_op_dist_attr.impl_idx) return dist_op.get_impl(forward_op_dist_attr.impl_idx)
dist_ops = get_distributed_operator_impl_container("default") # NOTE trick for dist ops that only have backward implement
return dist_ops.get_impl(0) if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0
return get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
def _get_dist_op_forward_implement(forward_op, dist_context): def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_ops = get_distributed_operator_impl_container(forward_op.type) dist_op = get_distributed_operator_impl_container(forward_op.type)
if dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
dist_attr.impl_idx)._forward_implemented: dist_attr.impl_idx)._forward_implemented:
return dist_ops.get_impl(dist_attr.impl_idx) return dist_op.get_impl(dist_attr.impl_idx)
else: else:
dist_ops = get_distributed_operator_impl_container("default") dist_op = get_distributed_operator_impl_container("default")
return dist_ops.get_impl(0) return dist_op.get_impl(0)
...@@ -25,6 +25,7 @@ import paddle.fluid.core as core ...@@ -25,6 +25,7 @@ import paddle.fluid.core as core
from paddle.framework.io import _to_LodTensor from paddle.framework.io import _to_LodTensor
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
def is_valid_list_index(list, index): def is_valid_list_index(list, index):
...@@ -993,18 +994,23 @@ def set_grad_var_shape(program, dist_context): ...@@ -993,18 +994,23 @@ def set_grad_var_shape(program, dist_context):
block = program.global_block() block = program.global_block()
vars = block.vars vars = block.vars
for op in block.ops: for op in block.ops:
if op.type in [
"sum", "check_finite_and_unscale", "update_loss_scaling" if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
]: break
if op.type in ["sum"]:
continue continue
if int(op.attr('op_role')) == int(OpRole.Backward): if int(op.attr('op_role')) == int(OpRole.Backward):
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None assert op_dist_attr is not None
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
assert "@GRAD" in var_name assert "@GRAD" in var_name
forward_var_name = var_name[:var_name.find("@GRAD")] forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale": if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast"
]:
forward_var_name = op.input_arg_names[0] forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad": elif op.type == "matmul_v2_grad":
forward_var_name = None forward_var_name = None
...@@ -1038,6 +1044,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1038,6 +1044,7 @@ def set_grad_var_shape(program, dist_context):
forward_input_dist_attr = op_dist_attr.get_input_dist_attr( forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name) forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name}" assert forward_input_dist_attr is not None, f"{forward_var_name}"
forward_var = vars[forward_var_name] forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
...@@ -1069,6 +1076,53 @@ def is_backward_op(op): ...@@ -1069,6 +1076,53 @@ def is_backward_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
def is_recompute_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == 9
def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
def get_loss_op(block):
loss_ops = []
for op in block.ops:
if is_loss_op(op):
assert len(op.desc.output_arg_names(
)) == 1, "loss op should only output loss var"
loss_ops.append(op)
assert len(loss_ops) == 1, "num of loss op is not equal to one"
return loss_ops[0]
def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = dims_mapping
# TODO get global mesh group
tensor_dist_attr.process_mesh = process_mesh
dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
return tensor_dist_attr
def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(new_op, process_mesh,
ref_mapping, ctx):
assert process_mesh is not None
assert ref_mapping is not None
new_op_dist_attr = OperatorDistributedAttribute()
for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping)
for output_varname in new_op.desc.output_arg_names():
new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping)
new_op_dist_attr.process_mesh = process_mesh
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def update_op_dims_mapping_by_default_dist_impl(dist_op): def update_op_dims_mapping_by_default_dist_impl(dist_op):
changed = False changed = False
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from .pass_base import new_pass, PassManager, PassContext from .pass_base import new_pass, PassManager, PassContext
from .fuse_all_reduce import * from .fuse_all_reduce import *
from .auto_parallel_sharding import *
from .cpp_pass import * from .cpp_pass import *
__all__ = [ __all__ = [
......
...@@ -5,4 +5,5 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -5,4 +5,5 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
list(APPEND DIST_TEST_OPS ${TEST_OP}) list(APPEND DIST_TEST_OPS ${TEST_OP})
set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 90) set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 90)
set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST")
endforeach(TEST_OP) endforeach(TEST_OP)
...@@ -63,12 +63,12 @@ class AutoPallelPassTestBase(DistPassTestBase): ...@@ -63,12 +63,12 @@ class AutoPallelPassTestBase(DistPassTestBase):
def check_main(self, gpus=None, **kwargs): def check_main(self, gpus=None, **kwargs):
no_pass_rets = self._distributed_launch( no_pass_rets = self._distributed_launch(
apply_pass=False, gpus=gpus, **kwargs) model=None, apply_pass=False, gpus=gpus, **kwargs)
pass_rets = self._distributed_launch( pass_rets = self._distributed_launch(
apply_pass=True, gpus=gpus, **kwargs) model=None, apply_pass=True, gpus=gpus, **kwargs)
self.check_results(no_pass_rets, pass_rets) self.check_results(no_pass_rets, pass_rets)
def _run_gpu_main(self, apply_pass, dump_file, **kwargs): def _run_gpu_main(self, model, apply_pass, dump_file, **kwargs):
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = paddle.CUDAPlace(gpu_id) place = paddle.CUDAPlace(gpu_id)
scope = paddle.static.Scope() scope = paddle.static.Scope()
...@@ -82,8 +82,8 @@ class AutoPallelPassTestBase(DistPassTestBase): ...@@ -82,8 +82,8 @@ class AutoPallelPassTestBase(DistPassTestBase):
with paddle.fluid.unique_name.guard(): with paddle.fluid.unique_name.guard():
main_prog, startup_prog, inputs, outputs, reader = self.get_model( main_prog, startup_prog, inputs, outputs, reader = self.get_model(
place, **kwargs) place, **kwargs)
inputs = self._to_var_names(main_prog, inputs) inputs = self._to_var_names(inputs)
outputs = self._to_var_names(main_prog, outputs) outputs = self._to_var_names(outputs)
all_fetch_values = [] all_fetch_values = []
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.passes import new_pass, PassManager
from auto_parallel_pass_test_base import AutoPallelPassTestBase
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
class TestShardingPass(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-5
self.atol = 1e-8
rank = paddle.distributed.get_rank()
paddle.seed(rank + 2021)
random.seed(rank + 2021)
np.random.seed(rank + 2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
}
fleet.init(is_collective=True, strategy=dist_strategy)
def apply_no_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.pipeline = False
dist_strategy.recompute = False
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_bs_8(self):
self.check_main(
gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000)
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model('dp', place, batch_size, sequence_len,
vocab_size)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册