From 025053b46f3711981181469714be143c829b6dd7 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 24 Nov 2021 15:12:53 +0800 Subject: [PATCH] Adapt auto search (#37490) * adapt auto search * adapt auto search * fix matmulv2 compatible * del debug --- .../framework/distributed_strategy.proto | 1 + .../distributed/auto_parallel/completion.py | 55 ++-- .../distributed/auto_parallel/dist_context.py | 22 ++ .../distributed/auto_parallel/dist_op.py | 11 + .../distributed/auto_parallel/dist_tensor.py | 11 + .../auto_parallel/operators/common.py | 58 ++-- .../auto_parallel/operators/dist_embedding.py | 65 +++- .../auto_parallel/operators/dist_matmul.py | 280 +++++++++++++++--- .../distributed/auto_parallel/parallelizer.py | 39 ++- .../paddle/distributed/auto_parallel/utils.py | 55 ++++ .../fleet/base/distributed_strategy.py | 23 ++ 11 files changed, 502 insertions(+), 118 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index bd84471e63..5aef432635 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -273,6 +273,7 @@ message DistributedStrategy { optional bool fuse_grad_merge = 34 [ default = false ]; optional bool semi_auto = 35 [ default = false ]; optional bool adam_d2sum = 36 [ default = true ]; + optional bool auto_search = 37 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 934239c0cd..745a018e8c 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -715,6 +715,27 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): grad_op_dist_attr.process_mesh = forward_op_process_mesh # var + for input_name in grad_op.input_arg_names: + input_var = vars[input_name] + ref_dims_mapping = None + if "@GRAD" in input_name: + forward_name = _get_forward_varname_from_grad_varname( + input_name) + ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( + forward_name) + else: + if forward_op_dist_attr.get_input_dims_mapping(input_name): + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + input_name) + else: + ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( + input_name) + + assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( + input_var.name) + grad_op_dist_attr.set_input_dims_mapping(input_name, + ref_dims_mapping) + for output_name in grad_op.desc.output_names(): assert len(grad_op.desc.output(output_name)) in [0, 1] if _is_grad_var_name(output_name): @@ -726,41 +747,25 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ] input_name = "X" assert input_name in forward_op.desc.input_names( - ), "var [{}] in op [{}]'s output but coulf not find [{}] in its forward op".format( + ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format( output_name, grad_op.type, input_name) if len(grad_op.desc.output(output_name)) == 1: - assert len(forward_op.desc.input(input_name)) == 1 - input_var = vars[forward_op.desc.input(input_name)[0]] - input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - input_var) - assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( - input_var.name) - 
ref_dims_mapping = input_var_dist_attr.dims_mapping - # tensor dist attr output_var = vars[grad_op.desc.output(output_name)[0]] + forward_name = _get_forward_varname_from_grad_varname( + output_var.name) + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + forward_name) + output_var_dist_attr = TensorDistributedAttribute() output_var_dist_attr.dims_mapping = ref_dims_mapping output_var_dist_attr.process_mesh = forward_op_process_mesh dist_context.set_tensor_dist_attr_for_program( output_var, output_var_dist_attr) - # op dist attr grad_op_dist_attr.set_output_dims_mapping(output_var.name, ref_dims_mapping) - for input_name in grad_op.input_arg_names: - input_var = vars[input_name] - input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - input_var) - assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( - input_var.name) - ref_dims_mapping = input_var_dist_attr.dims_mapping - assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( - input_var.name) - grad_op_dist_attr.set_input_dims_mapping(input_name, - ref_dims_mapping) - dist_context.set_op_dist_attr_for_program(grad_op, grad_op_dist_attr) @@ -828,13 +833,7 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): param_dist_attr = dist_context.get_tensor_dist_attr_for_program( param) - grad_dist_attr = dist_context.get_tensor_dist_attr_for_program( - grad_var) - assert param_dist_attr is not None - assert grad_dist_attr is not None - assert param_dist_attr.dims_mapping == grad_dist_attr.dims_mapping - ref_process_mesh = dist_context.get_tensor_dist_attr_for_program( param).process_mesh assert ref_process_mesh is not None diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index e3b3ee6a37..347d02dacf 100755 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -335,6 +335,17 @@ class DistributedContext: dist_op.serial_op.type, dist_tensor.dist_attr) return True + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_serial_program" or k == "_serial_graph": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + class DistributedOperatorContext: """ @@ -352,6 +363,17 @@ class DistributedOperatorContext: self.gradopidx2opidx = {} self.already_init_sync_vars = set() + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_dst_main_program" or k == "_dst_startup_program" or k == "_cur_src_op": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + def set_dst_main_program(self, prog): self._dst_main_program = prog diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index aa447d7a42..ef595e2a00 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -219,6 +219,17 @@ class DistributedOperator: return str + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_serial_op" or k == "_serial_inputs" or k == "_serial_outputs": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + 
return result + class DistributedModule: def __init__(self, serial_module, dist_attr=None): diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py index 3b292d7f43..f46c6e86d6 100644 --- a/python/paddle/distributed/auto_parallel/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -66,6 +66,17 @@ class DistributedTensor: return False return True + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_serial_tensor": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + def __str__(self): str = "{{tensor name: {}, tensor id: {}".format( self.serial_tensor.desc.name(), self.serial_tensor.desc.id()) diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 376e1a8ac6..678f4e7fdc 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -111,37 +111,27 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): return best_compatible_impl, idx -def copy_distributed_attr_for_var(dist_context, dst_var, src_var): - """ - copy src var's dist_attr to dst var - """ - dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) - dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr) - - -def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block, - src_op_dist_attr): - """ - copy src op's dist_attr to dst dist op - """ - from ..dist_attribute import OperatorDistributedAttribute - # need check dist op attr and its inputs and outputs - - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = src_op_dist_attr.process_mesh - op_dist_attr.impl_idx = src_op_dist_attr.impl_idx - - for input_varname in dist_op.desc.input_arg_names(): - input_var = dst_block.var(input_varname) - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - input_var) - op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr) - - for output_varname in dist_op.desc.output_arg_names(): - output_var = dst_block.var(output_varname) - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - output_var) - op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr) - - dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr) - op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op) +def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): + var_shape = block.var(src_var.name).shape + var_topoloy = src_var_dist_attr.process_mesh.topology + var_dims_mapping = src_var_dist_attr.dims_mapping + + complete_shape = [] + for idx, shape in enumerate(var_shape): + if var_dims_mapping[idx] == -1: + complete_shape.append(shape) + else: + new_shape = shape * var_topoloy[var_dims_mapping[idx]] + complete_shape.append(new_shape) + + exact_shape = [] + input_topology = op_input_dist_attr.process_mesh.topology + input_dims_mapping = op_input_dist_attr.dims_mapping + for idx, shape in enumerate(complete_shape): + if input_dims_mapping[idx] == -1: + exact_shape.append(shape) + else: + new_shape = shape // input_topology[input_dims_mapping[idx]] + exact_shape.append(new_shape) + + return exact_shape diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py 
b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 0099d6a09c..3df04a70a5 100755 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License +from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl -from .common import copy_distributed_attr_for_var -from .common import copy_distributed_attr_for_dist_op from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -172,6 +171,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], 'c_embedding') + # infer new var shape with op dist attr + out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) + assert out_tensor_dist_attr is not None + out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert out_var_dist_attr is not None + ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, + out_var_dist_attr) + intermediate_var_0 = main_block.create_var( name=unique_name.generate_with_ignorable_key(".".join( ["c_embedding", 'tmp'])), @@ -180,9 +187,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=Out_var.stop_gradient) - - # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var) + # set intermediate_var_0's dist_attr with Out_var's dist_attr + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + out_var_dist_attr) check_variable_and_dtype( Out_var, 'tensor', @@ -195,6 +202,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'W': [Weight_var]}, outputs={'Out': [intermediate_var_0]}, attrs={"start_index": relative_idx}) + if intermediate_var_0.shape != ref_shape: + intermediate_var_0.desc.set_shape(ref_shape) # use_model_parallel c_allreduce_sum_op = main_block.append_op( @@ -206,12 +215,46 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): 'use_calc_stream': True, 'use_model_parallel': True, }) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block, - op_dist_attr) + if Out_var.shape != ref_shape: + Out_var.desc.set_shape(ref_shape) + + # set dist op's dist_attr with serial op's dist_attr + # matmulv2 + embedding_op_dist_attr = OperatorDistributedAttribute() + embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh + embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in c_embedding_op.desc.input_arg_names(): + input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + embedding_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + output_varname = c_embedding_op.desc.output_arg_names()[0] + output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + embedding_op_dist_attr.set_output_dist_attr(output_varname, + 
output_dist_attr) + ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr) + + # allreduce + allreduce_op_dist_attr = OperatorDistributedAttribute() + allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in c_allreduce_sum_op.desc.input_arg_names(): + input_var = main_block.var(input_varname) + tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) + assert tensor_dist_attr is not None + allreduce_op_dist_attr.set_input_dist_attr(input_varname, + tensor_dist_attr) + for output_varname in c_allreduce_sum_op.desc.output_arg_names(): + output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + allreduce_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, + allreduce_op_dist_attr) # param initialization sync assert Weight_var.name not in dist_op_context.already_init_sync_vars diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 43816ba88a..786d24052e 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License +from .common import infer_shape from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl -from .common import copy_distributed_attr_for_var -from .common import copy_distributed_attr_for_dist_op from ..utils import is_dim_shard from ..utils import is_dim_replicate from ..utils import is_valid_list_index @@ -356,6 +355,21 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): parallel_axis, rank_id) group = new_process_group(group_ranks) + # infer new var shape with op dist attr + x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var) + assert x_tensor_dist_attr is not None + identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) + assert identity_var_dist_attr is not None + ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, + identity_var_dist_attr) + # infer out var shape with op dist attr + out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) + assert out_tensor_dist_attr is not None + out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert out_var_dist_attr is not None + ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, + out_var_dist_attr) + intermediate_var_0 = main_block.create_var( name=unique_name.generate_with_ignorable_key(".".join( ["c_identity", 'tmp'])), @@ -364,8 +378,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=X_var.stop_gradient) - # copy X_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var) + # set intermediate_var_0's dist_attr with X_var's dist_attr + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + identity_var_dist_attr) check_variable_and_dtype( X_var, 'tensor', @@ -380,6 +395,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): 'use_calc_stream': True, 
'use_model_parallel': True, }) + if intermediate_var_0.shape != ref_shape_x: + intermediate_var_0.desc.set_shape(ref_shape_x) check_variable_and_dtype(intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear') @@ -393,12 +410,56 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} matmul_op = main_block.append_op( type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block, - op_dist_attr) + if Out_var.shape != ref_shape_out: + Out_var.desc.set_shape(ref_shape_out) + + # set dist op's dist_attr with serial op's dist_attr + # c_identity + identity_op_dist_attr = OperatorDistributedAttribute() + identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh + identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx + # input + input_varname = c_identity_op.desc.input_arg_names()[0] + input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + identity_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + # output + output_varname = c_identity_op.desc.output_arg_names()[0] + identity_op_dist_attr.set_output_dist_attr(output_varname, + input_dist_attr) + # set op dist attr + ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) + + # matmul + matmul_op_dist_attr = OperatorDistributedAttribute() + matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx + # input + for input_varname in matmul_op.desc.input_arg_names(): + if input_varname in src_op.desc.input_arg_names(): + input_dist_attr = op_dist_attr.get_input_dist_attr( + input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmul_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + else: + input_var = main_block.var(input_varname) + tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( + input_var) + matmul_op_dist_attr.set_input_dist_attr(input_varname, + tensor_dist_attr) + # output + output_varname = matmul_op.desc.output_arg_names()[0] + output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmul_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + # set op dist attr + ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) # init param sync if Weight_var.is_parameter: @@ -518,6 +579,15 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): 'alpha': 1, } inputs = {'X': X_var, 'Y': Weight_var} + + # infer out var shape with op dist attr + out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) + assert out_tensor_dist_attr is not None + out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert out_var_dist_attr is not None + ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, + out_var_dist_attr) + intermediate_var_0 = main_block.create_var( shape=Out_var.shape, dtype=Out_var.dtype, @@ -526,14 +596,17 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): persistable=False, is_data=False, need_check_feed=Out_var.desc.need_check_feed()) - # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(ctx, 
intermediate_var_0, Out_var) + # set intermediate_var_0's dist_attr with Out_var's dist_attr + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + out_var_dist_attr) matmul_op = main_block.append_op( type='matmul', inputs=inputs, outputs={'Out': intermediate_var_0}, attrs=attrs) + if intermediate_var_0.shape != ref_shape: + intermediate_var_0.desc.set_shape(ref_shape) c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', @@ -544,12 +617,46 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): 'use_calc_stream': True, 'use_model_parallel': True }) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block, - op_dist_attr) + if Out_var.shape != ref_shape: + Out_var.desc.set_shape(ref_shape) + + # set dist op's dist_attr with serial op's dist_attr + # matmul + matmul_op_dist_attr = OperatorDistributedAttribute() + matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in matmul_op.desc.input_arg_names(): + input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmul_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + output_varname = matmul_op.desc.output_arg_names()[0] + output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmul_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) + + # allreduce + allreduce_op_dist_attr = OperatorDistributedAttribute() + allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in c_allreduce_sum_op.desc.input_arg_names(): + input_var = main_block.var(input_varname) + tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) + assert tensor_dist_attr is not None + allreduce_op_dist_attr.set_input_dist_attr(input_varname, + tensor_dist_attr) + for output_varname in c_allreduce_sum_op.desc.output_arg_names(): + output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + allreduce_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, + allreduce_op_dist_attr) # init param sync if Weight_var.is_parameter: @@ -729,6 +836,21 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): parallel_axis, rank_id) group = new_process_group(group_ranks) + # infer new var shape with op dist attr + x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var) + assert x_tensor_dist_attr is not None + identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) + assert identity_var_dist_attr is not None + ref_shape_x = infer_shape(main_block, X_var, x_tensor_dist_attr, + identity_var_dist_attr) + # infer out var shape with op dist attr + out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) + assert out_tensor_dist_attr is not None + out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert out_var_dist_attr is not None + ref_shape_out = infer_shape(main_block, Out_var, out_tensor_dist_attr, + out_var_dist_attr) + intermediate_var_0 = 
main_block.create_var( name=unique_name.generate_with_ignorable_key(".".join( ["c_identity", 'tmp'])), @@ -737,13 +859,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=X_var.stop_gradient) - # copy X_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var) + # set intermediate_var_0's dist_attr with X_var's dist_attr + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + identity_var_dist_attr) check_variable_and_dtype( X_var, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity') - c_identity_op = main_block.append_op( type='c_identity', inputs={'X': [X_var]}, @@ -753,6 +875,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): 'use_calc_stream': True, 'use_model_parallel': True, }) + if intermediate_var_0.shape != ref_shape_x: + intermediate_var_0.desc.set_shape(ref_shape_x) check_variable_and_dtype(intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear') @@ -765,12 +889,52 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block, - op_dist_attr) + if Out_var.shape != ref_shape_out: + Out_var.desc.set_shape(ref_shape_out) + + # set dist op's dist_attr with serial op's dist_attr + # c_identity + identity_op_dist_attr = OperatorDistributedAttribute() + identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh + identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx + # input + input_varname = c_identity_op.desc.input_arg_names()[0] + input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + identity_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + # output + output_varname = c_identity_op.desc.output_arg_names()[0] + identity_op_dist_attr.set_output_dist_attr(output_varname, + input_dist_attr) + ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) + + # matmulv2 + matmulv2_op_dist_attr = OperatorDistributedAttribute() + matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in matmul_v2_op.desc.input_arg_names(): + if input_varname in src_op.desc.input_arg_names(): + input_dist_attr = op_dist_attr.get_input_dist_attr( + input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmulv2_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + else: + input_var = main_block.var(input_varname) + tensor_dist_attr = ctx.get_tensor_dist_attr_for_program( + input_var) + matmulv2_op_dist_attr.set_input_dist_attr(input_varname, + tensor_dist_attr) + for output_varname in matmul_v2_op.desc.output_arg_names(): + output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmulv2_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) # init param sync if Weight_var.is_parameter: @@ -886,6 +1050,15 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): 'linear') attrs = {'trans_x': False, 'trans_y': False} 
inputs = {'X': X_var, 'Y': Weight_var} + + # infer out var shape with op dist attr + out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) + assert out_tensor_dist_attr is not None + out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert out_var_dist_attr is not None + ref_shape = infer_shape(main_block, Out_var, out_tensor_dist_attr, + out_var_dist_attr) + intermediate_var_0 = main_block.create_var( shape=Out_var.shape, dtype=Out_var.dtype, @@ -894,14 +1067,17 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): persistable=False, is_data=False, need_check_feed=Out_var.desc.need_check_feed()) - # copy Out_var's dist_attr to intermediate_var_0's dist_attr - copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var) + # set intermediate_var_0's dist_attr with Out_var's dist_attr + ctx.set_tensor_dist_attr_for_program(intermediate_var_0, + out_var_dist_attr) matmul_v2_op = main_block.append_op( type='matmul_v2', inputs=inputs, outputs={'Out': intermediate_var_0}, attrs=attrs) + if intermediate_var_0.shape != ref_shape: + intermediate_var_0.desc.set_shape(ref_shape) c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', @@ -912,12 +1088,46 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): 'use_calc_stream': True, 'use_model_parallel': True }) - - # copy serial op's dist_attr to dist op's dist_attr - copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block, - op_dist_attr) - copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block, - op_dist_attr) + if Out_var.shape != ref_shape: + Out_var.desc.set_shape(ref_shape) + + # set dist op's dist_attr with serial op's dist_attr + # matmulv2 + matmulv2_op_dist_attr = OperatorDistributedAttribute() + matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in matmul_v2_op.desc.input_arg_names(): + input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) + assert input_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmulv2_op_dist_attr.set_input_dist_attr(input_varname, + input_dist_attr) + output_varname = matmul_v2_op.desc.output_arg_names()[0] + output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + matmulv2_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) + + # allreduce + allreduce_op_dist_attr = OperatorDistributedAttribute() + allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx + for input_varname in c_allreduce_sum_op.desc.input_arg_names(): + input_var = main_block.var(input_varname) + tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var) + assert tensor_dist_attr is not None + allreduce_op_dist_attr.set_input_dist_attr(input_varname, + tensor_dist_attr) + for output_varname in c_allreduce_sum_op.desc.output_arg_names(): + output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname) + assert output_dist_attr is not None, "dist_attr is {}".format( + op_dist_attr) + allreduce_op_dist_attr.set_output_dist_attr(output_varname, + output_dist_attr) + ctx.set_op_dist_attr_for_program(c_allreduce_sum_op, + allreduce_op_dist_attr) # init param sync if Weight_var.is_parameter: diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py 
b/python/paddle/distributed/auto_parallel/parallelizer.py
index 7a0cbd7da3..4e2f83bd34 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import paddle
+from paddle.distributed.utils import get_logger
 from paddle.distributed.fleet import cloud_utils
 import paddle.fluid.core as core
 from .dist_context import DistributedContext
@@ -22,7 +24,11 @@ from .completion import complete_annotation, complete_backward_annotation
 from .partitioner import Partitioner
 from .process_group import get_all_process_groups
 from .utils import make_data_unshard
+from .utils import set_grad_var_shape
 from .reshard import reshard
+# from .auto_search import auto_search
+
+_logger = get_logger(logging.INFO)
 
 
 class AutoParallelizer:
@@ -59,9 +65,19 @@ class AutoParallelizer:
         assert startup_program is not None
         main_program = loss.block.program
 
-        # Annotation completion
-        completed_main_program = complete_annotation(main_program,
-                                                     self._dist_context)
+        if self._dist_strategy.auto_search:
+            # auto search
+            _logger.info("Start search dist attr.")
+            # self._dist_context, _ = auto_search(main_program, startup_program,
+            #                                     loss, self._optimizer)
+            # completed_main_program = main_program
+            raise NotImplementedError("Auto search has not been implemented")
+        else:
+            # Annotation completion
+            _logger.info("Start annotation dist attr.")
+            completed_main_program = complete_annotation(main_program,
+                                                         self._dist_context)
+
         # Logical partition
         rank = paddle.distributed.get_rank()
         partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
@@ -74,13 +90,8 @@ class AutoParallelizer:
             self._optimizer, dist_params_grads, partitioned_main_prog,
             partitioned_startup_prog)
 
-        # Traverse different rank programs and traverse each op of them,
-        # instantiate communication by process_mapping.
-        all_process_groups = get_all_process_groups()
-        for process_group in all_process_groups:
-            if rank not in process_group._ranks:
-                continue
-            process_group.instantiate()
+        # set the grad var shape
+        set_grad_var_shape(partitioned_main_prog, self._dist_context)
 
         # The last step: remove all distributed attributes to be compatiable
         # with inference.
@@ -91,6 +102,14 @@
             reshard(partitioned_main_prog, partitioned_startup_prog, rank,
                     self._dist_context)
 
+        # Traverse different rank programs and traverse each op of them,
+        # instantiate communication by process_mapping.
+        all_process_groups = get_all_process_groups()
+        for process_group in all_process_groups:
+            if rank not in process_group._ranks:
+                continue
+            process_group.instantiate()
+
         # Copy distributed info to the default context
         set_default_distributed_context(self._dist_context)
 
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 9c1f9a8c7c..a3505eae87 100755
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -981,3 +981,58 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
             complete_shape))
     split_indices_list = [sorted(x) for x in split_indices_list]
     return split_indices_list
+
+
+def set_grad_var_shape(program, dist_context):
+    from .operators.common import infer_shape
+    from paddle.distributed.fleet.meta_optimizers.common import OpRole
+
+    block = program.global_block()
+    vars = block.vars
+    for op in block.ops:
+        if op.type == "sum":
+            continue
+        if int(op.attr('op_role')) == int(OpRole.Backward):
+            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
+            assert op_dist_attr is not None
+
+            for var_name in op.output_arg_names:
+                assert "@GRAD" in var_name
+                forward_var_name = var_name[:var_name.find("@GRAD")]
+                if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
+                    forward_var_name = op.input_arg_names[0]
+
+                need_set_shape_list = [
+                    "reshape2_grad", "softmax_with_cross_entropy_grad",
+                    "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
+                    "dropout_grad"
+                ]
+                forward_list = [
+                    "reshape2", "softmax_with_cross_entropy", "transpose2",
+                    "softmax", "cross_entropy2", "dropout"
+                ]
+                if op.type in need_set_shape_list:
+                    for forward_op in block.ops:
+                        assert int(forward_op.attr('op_role')) != int(
+                            OpRole.Backward)
+                        idx = need_set_shape_list.index(op.type)
+                        forward_op_name = forward_list[idx]
+                        if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names:
+                            op_dist_attr = dist_context.get_op_dist_attr_for_program(
+                                forward_op)
+                            break
+
+                forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
+                    forward_var_name)
+                assert forward_input_dist_attr is not None, f"{forward_var_name}"
+                forward_var = vars[forward_var_name]
+                forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
+                    forward_var)
+                assert forward_var_dist_attr is not None
+                grad_var = vars[var_name]
+                ref_shape = infer_shape(block, forward_var,
+                                        forward_var_dist_attr,
+                                        forward_input_dist_attr)
+
+                if list(grad_var.shape) != ref_shape:
+                    grad_var.desc.set_shape(ref_shape)
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 378c2ff8d5..975c7b3f74 100644
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1631,6 +1631,29 @@ class DistributedStrategy(object):
         else:
             print("WARNING: semi-auto should have value of bool type")
 
+    @property
+    def auto_search(self):
+        """
+        Indicating whether we are using auto-search parallel function
+        For details, please reference the following code example
+        Default Value: False
+        Examples:
+          .. code-block:: python
+            import paddle
+            paddle.enable_static()
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.auto_search = True
+        """
+        return self.strategy.auto_search
+
+    @auto_search.setter
+    def auto_search(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.auto_search = flag
+        else:
+            print("WARNING: auto-search should have value of bool type")
+
     @property
     def cudnn_exhaustive_search(self):
         """
-- GitLab
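
The infer_shape helper added to operators/common.py (and reused by set_grad_var_shape above) works in two steps: it first recovers a tensor's complete global shape from its current sharding, then re-divides that shape according to the dist attr of the consuming op input. The following is a minimal, self-contained sketch of that computation; the function name and the plain-list arguments are illustrative stand-ins for the real dist-attr objects, not part of the patch.

    # Illustrative sketch (not part of the patch): how infer_shape derives a
    # tensor's exact (per-rank) shape from dims_mapping and the process-mesh
    # topology. Plain lists stand in for the real dist-attr objects.
    def infer_exact_shape(local_shape, src_topology, src_dims_mapping,
                          dst_topology, dst_dims_mapping):
        # 1) Recover the complete (global) shape from the source sharding:
        #    a dim mapped to mesh axis a is multiplied by that axis' size.
        complete_shape = [
            s if src_dims_mapping[i] == -1 else
            s * src_topology[src_dims_mapping[i]]
            for i, s in enumerate(local_shape)
        ]
        # 2) Re-shard the complete shape with the destination dist attr:
        #    a mapped dim is divided by the size of its mesh axis.
        return [
            s if dst_dims_mapping[i] == -1 else
            s // dst_topology[dst_dims_mapping[i]]
            for i, s in enumerate(complete_shape)
        ]

    # A [4, 1024] local slice sharded 2-way on dim 0 has global shape
    # [8, 1024]; re-sharding it 4-way on dim 0 yields a local shape [2, 1024].
    assert infer_exact_shape([4, 1024], [2], [0, -1], [4], [0, -1]) == [2, 1024]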
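
DistributedContext, DistributedOperatorContext, DistributedOperator and DistributedTensor all gain the same selective __deepcopy__ so that an auto-search pass can clone the planning state cheaply: heavyweight framework objects (serial programs, graphs, serial ops and tensors) are shared by reference, everything else is deep-copied, and the memo dict keeps shared sub-objects consistent. A minimal sketch of the pattern on a hypothetical PlanState class (not part of the patch):

    # Illustrative sketch (not part of the patch) of the selective
    # __deepcopy__ pattern used by the classes above.
    import copy

    class PlanState:
        # "_program" stands in for the attributes the patch shares by reference.
        _SHARED = ("_program",)

        def __init__(self, program, attrs):
            self._program = program   # large object, not copied per search step
            self.attrs = attrs        # per-plan data that must stay independent

        def __deepcopy__(self, memo):
            cls = self.__class__
            result = cls.__new__(cls)
            memo[id(self)] = result
            for k, v in self.__dict__.items():
                if k in self._SHARED:
                    setattr(result, k, v)                       # share by reference
                else:
                    setattr(result, k, copy.deepcopy(v, memo))  # independent copy
            return result

    base = PlanState(program=object(), attrs={"dims_mapping": [0, -1]})
    clone = copy.deepcopy(base)
    clone.attrs["dims_mapping"][0] = -1
    assert base.attrs["dims_mapping"] == [0, -1]   # attrs are independent
    assert clone._program is base._program         # program is shared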
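
Finally, the new auto_search flag is exposed on fleet.DistributedStrategy in the same way as semi_auto. Below is a hedged usage sketch based on the docstring example above; note that with this patch the parallelizer still raises NotImplementedError when the flag is on, so the snippet only shows how the switch is wired up (the fleet.init keyword arguments follow the existing fleet API and are not introduced by this change).

    # Usage sketch: turning on the auto_search strategy flag.
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()

    strategy = fleet.DistributedStrategy()
    strategy.auto_search = True   # ask the framework to search dist attributes
    # strategy.semi_auto = True   # the pre-existing annotation-driven mode

    fleet.init(is_collective=True, strategy=strategy)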