From e9364a3897fab2e7d8051177ca7b8dde8d76a8ae Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 6 Sep 2023 20:48:32 +0800 Subject: [PATCH] [AutoParallel] Generate spmd rule and reshard impl in phi api (#56831) * add spmd and reshard code gen * add backward reshard code gen * test matmul forward success * polish test impl * add unsafe mutable value * polish details and add test * fix unittest time out * fix typo * refactor reshard input generate impl * resolve conflict with develop * fix compile error --- .../fluid/operators/generator/parse_utils.py | 2 +- paddle/phi/api/lib/api_gen_utils.cc | 14 +- paddle/phi/api/lib/api_gen_utils.h | 19 +- paddle/phi/api/lib/data_transform.cc | 52 +++- paddle/phi/api/lib/data_transform.h | 14 +- paddle/phi/api/lib/kernel_dispatch.cc | 3 + paddle/phi/api/lib/kernel_dispatch.h | 2 - paddle/phi/api/yaml/generator/api_gen.py | 4 + paddle/phi/api/yaml/generator/dist_api_gen.py | 246 +++++++++++++----- .../phi/api/yaml/generator/dist_bw_api_gen.py | 47 +++- paddle/phi/api/yaml/legacy_ops.yaml | 1 + .../distributed/auto_parallel/dist_tensor.h | 9 + .../auto_parallel/inferspmd_utils.cc | 5 - paddle/phi/infermeta/spmd_rules/rules.h | 6 +- test/auto_parallel/CMakeLists.txt | 4 + .../semi_auto_parallel_for_matmul.py | 84 ++++++ .../test_semi_auto_parallel_basic.py | 41 +++ 17 files changed, 455 insertions(+), 98 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_for_matmul.py create mode 100644 test/auto_parallel/test_semi_auto_parallel_basic.py diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 92834be0f01..3a2429f5345 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -366,7 +366,7 @@ def check_op_config(op_entry, op_name): 'composite', 'support_dygraph_mode', ) - infer_meta_key_set = ('func', 'param') + infer_meta_key_set = ('func', 'param', 'spmd_rule') kernel_key_set = ( 'func', 'param', diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index c6da10d12de..c7501494b1e 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -22,6 +22,7 @@ PHI_DECLARE_bool(use_stride_kernel); #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" namespace paddle { @@ -530,13 +531,18 @@ void TransStride(phi::DeviceContext* dev_ctx, /* ------------------ for auto parallel ----------------------- */ -phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) { +phi::distributed::DistMetaTensor MakeDistMetaTensor( + const phi::TensorBase& tensor) { + return phi::distributed::DistMetaTensor(tensor); +} + +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { if (out) { // TODO(chenweihang): now all dist case are nullptr if (out->impl() == nullptr) { - // TODO(chenweihang): polish code, dist_attr is null now - auto dist_t = std::make_shared( - phi::DDim(), phi::distributed::TensorDistAttr()); + auto dist_t = std::make_shared(phi::DDim(), + dist_attr); out->set_impl(dist_t); } return static_cast(out->impl().get()); diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index 997bb6f8dc8..d0281dfc681 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -18,18 +18,15 @@ limitations under the License. */ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/string_tensor.h" -namespace phi { -namespace distributed { -class DistTensor; -} // namespace distributed -} // namespace phi - namespace paddle { namespace experimental { @@ -139,9 +136,17 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, /* ------------------ for auto parallel ----------------------- */ -phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out); +phi::distributed::DistMetaTensor MakeDistMetaTensor( + const phi::TensorBase& tensor); + +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, + const phi::distributed::TensorDistAttr& dist_attr = + phi::distributed::TensorDistAttr()); + std::vector SetKernelDistOutput( std::vector out); + std::vector SetKernelDistOutput( size_t out_size, std::vector* out); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 0e86b84e074..7515ff917f1 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" @@ -597,15 +599,45 @@ void TransDataBackend(const phi::SelectedRows* tensor, /* ------------------ for auto parallel ----------------------- */ +std::shared_ptr ReshardDistTensor( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::TensorDistAttr& dist_attr) { + auto tensor_in = tensor.impl(); + if (tensor_in) { + phi::distributed::DistTensor* dist_tensor = + static_cast(tensor_in.get()); + if (dist_tensor->dist_attr() != dist_attr) { + VLOG(6) << "Reshard tensor from " << dist_tensor->dist_attr() << " to " + << dist_attr; + auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, + dist_attr); + return func->Eval(dev_ctx, *dist_tensor, dist_attr); + } + return std::static_pointer_cast(tensor_in); + } + return nullptr; +} + std::shared_ptr PrepareDataForDistTensor( const Tensor& input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { - const auto& tensor_in = input.impl(); - if (tensor_in) { - phi::distributed::DistTensor* dist_tensor = - static_cast(tensor_in.get()); + return PrepareDataForDistTensor( + std::static_pointer_cast(input.impl()), + target_args_def, + transform_flag, + is_stride_kernel); +} + +std::shared_ptr PrepareDataForDistTensor( + const std::shared_ptr& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { + if (input) { + phi::distributed::DistTensor* dist_tensor = input.get(); const phi::DenseTensor& dense_tensor = dist_tensor->value(); if (!transform_flag.NeedTransform() || !dense_tensor.initialized() || (!NeedTransformPlace( @@ -618,16 +650,18 @@ std::shared_ptr PrepareDataForDistTensor( transform_flag) && !NeedTransform2Contiguous(is_stride_kernel, dense_tensor.meta().is_contiguous()))) { - return std::static_pointer_cast(tensor_in); + return input; } - phi::DenseTensor out = TransformData( - dense_tensor, target_args_def, transform_flag, is_stride_kernel); // TODO(chenweihang): The global meta in DistTensor is not changed, // but the local meta in DenseTensor maybe changed, such as layout // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified. VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor"; - return std::make_shared( - out, dist_tensor->dist_attr()); + auto dist_out = std::make_shared( + dist_tensor->dims(), dist_tensor->dist_attr()); + auto* out = dist_out->unsafe_mutable_value(); + *out = TransformData( + dense_tensor, target_args_def, transform_flag, is_stride_kernel); + return dist_out; } return nullptr; } diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 4247317857c..3ac1b94f144 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -21,8 +21,10 @@ limitations under the License. */ #include "paddle/phi/core/sparse_csr_tensor.h" namespace phi { +class DeviceContext; namespace distributed { class DistTensor; +class TensorDistAttr; } // namespace distributed } // namespace phi @@ -173,13 +175,23 @@ inline bool NeedTransformPlace(const phi::Place& src_place, /* ------------------ for auto parallel ----------------------- */ -// TODO(chenweihang): impl Reshard input and output function +std::shared_ptr ReshardDistTensor( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::TensorDistAttr& dist_attr); + std::shared_ptr PrepareDataForDistTensor( const Tensor& input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); +std::shared_ptr PrepareDataForDistTensor( + const std::shared_ptr& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel); + std::vector> PrepareDataForDistTensor(const std::vector& input, const phi::TensorArgDef& target_args_def, diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index 81b90769e81..2ebd3c46d5f 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/string_tensor_utils.h" #include "paddle/phi/core/tensor_utils.h" #ifdef PADDLE_WITH_CUSTOM_DEVICE @@ -50,6 +51,8 @@ bool HasAllocation(const phi::TensorBase& t) { } else if (phi::StringTensor::classof(&t)) { return phi::StringTensorUtils::GetHolder( static_cast(t)) != nullptr; + } else if (phi::distributed::DistTensor::classof(&t)) { + return static_cast(t).defined(); } else { return false; } diff --git a/paddle/phi/api/lib/kernel_dispatch.h b/paddle/phi/api/lib/kernel_dispatch.h index 7ff9ab3b33f..6acc23b2db7 100644 --- a/paddle/phi/api/lib/kernel_dispatch.h +++ b/paddle/phi/api/lib/kernel_dispatch.h @@ -96,8 +96,6 @@ struct KernelKeyParser : ArgsIterator { // data_promote DataTypeSet dtype_set{DataType::UNDEFINED}; - // TODO(chenweihang): deal with multiple diff input Tensors - // TODO(chenweihang): add global device guard method to set backend inline void AssignKernelKeySet(const phi::TensorBase& tensor) { // assign Backend BackendSet tensor_backend_set = detail::GetTensorBackendSet(tensor); diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 5164ebda840..0c47c232768 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -379,6 +379,10 @@ def source_include(header_file_path): #include "paddle/phi/api/profiler/event_tracing.h" #include "paddle/phi/api/profiler/supplement_tracing.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + PD_DECLARE_bool(conv2d_disable_cudnn); PD_DECLARE_int32(low_precision_op_list); """ diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 5c0b642228b..ed671ecdfeb 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -44,14 +44,14 @@ PADDLE_THROW(phi::errors::Unimplemented( MAIN_DIST_BRANCH_TEMPLATE = """ // Auto Parallel condition if ({}) {{ - // 1. Create API Output & Prepare Dist and Dense Output{} - // 2. InferSPMD (Infer Global Shape and DistAttr of Inputs&Outputs){} - // 3. Select Kernel{} - // 4. Reshard Input{} - // 5. PrepareData (DataTransform & Prepare Dist and Dense Input){} - // 6. Infer Local DenseTensor Meta{} - // 7. DenseTensor Kernel Call{} - // 8. Reshard Output{} + // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){} + // 2. Create API Output & Prepare Dist and Dense Output{} + // 3. Infer DistTensor's Global Shape{} + // 4. Select Kernel{} + // 5. Reshard Input{}\n + // 6. PrepareData (DataTransform & Prepare Dense Input){} + // 7. Infer Local DenseTensor Meta{} + // 8. DenseTensor Kernel Call{} // 9. Return {} }} @@ -60,20 +60,35 @@ MAIN_DIST_BRANCH_TEMPLATE = """ # Auto Parallel condition AUTO_PARALLEL_COND_TEMPLATE = """AllInputsAreDistTensor({})""" -# 1. Create API Outputs +# 1. InferSPMD +SINGLE_DIST_META_IN_TEMPLATE = """ + auto meta_dist_{} = MakeDistMetaTensor(*{}.impl());""" +INFER_SPMD_TEMPLATE = """ + auto spmd_info = phi::distributed::{}({}); +""" + +# 2. Create API Outputs API_OUT_CREATION_TEMPLATE = """ {} api_output{}; """ INPLACE_API_OUT_CREATION_TEMPLATE = """ {} api_output{{{}}}; """ -SINGLE_OUT_CREATION_TEMPLATE = """ +SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ auto dist_out = SetKernelDistOutput(&api_output); - auto dense_out = const_cast(&dist_out->value()); + auto dense_out = dist_out->unsafe_mutable_value(); +""" +MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """ + auto dist_out_{idx} = SetKernelDistOutput({out}); + auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); +""" +SINGLE_OUT_CREATION_TEMPLATE = """ + auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]); + auto dense_out = dist_out->unsafe_mutable_value(); """ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ - auto dist_out_{} = SetKernelDistOutput({}); - auto dense_out_{} = const_cast(&dist_out_{}->value()); + auto dist_out_{idx} = SetKernelDistOutput({out}, spmd_info.second[{idx}]); + auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value(); """ VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput({}, &api_output); @@ -93,12 +108,12 @@ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """ TUPLE_OUT_CREATION_TEMPLATE = """ """ -# 2. InferSPMD -# Call InferMeta now, replace by InferSPMD function later -# TODO(chenweihang): InferSPMD function design -SINGLE_DIST_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """ -VECTOR_DIST_META_IN_TEMPLATE = """{}_meta_ptr_vec, """ -VECTOR_DIST_META_IN_DECL_TEMPLATE = """ +# 3. Infer Global Shape +# TODO(chenweihang): the input MetaTensor created by Inferspmd can be reused +# for InferGlobalShape to avoid creating repeated inputs. +SINGLE_GLOBAL_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """ +VECTOR_GLOBAL_META_IN_TEMPLATE = """{}_meta_ptr_vec, """ +VECTOR_GLOBAL_META_IN_DECL_TEMPLATE = """ std::vector {name}_meta_vec; for (auto tmp : {name}) {{ {name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl())); @@ -109,11 +124,11 @@ VECTOR_DIST_META_IN_DECL_TEMPLATE = """ }} """ # TODO(GhostScreaming): support optional args later -OPTIONAL_DIST_VECTOR_META_IN_TEMPLATE = """ +OPTIONAL_GLOBAL_VECTOR_META_IN_TEMPLATE = """ """ -SINGLE_DIST_META_OUT_DECL_TEMPLATE = """ +SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE = """ phi::MetaTensor meta_{}({});""" -VECTOR_DIST_META_OUT_DECL_TEMPLATE = """ +VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE = """ std::vector {name}_meta_vec; for (auto tmp : {name}) {{ {name}_meta_vec.emplace_back(phi::MetaTensor(tmp)); @@ -123,11 +138,11 @@ VECTOR_DIST_META_OUT_DECL_TEMPLATE = """ {name}_meta_ptr_vec[i] = &{name}_meta_vec[i]; }} """ -INFER_SPMD_TEMPLATE = """ +INFER_GLOBAL_SHAPE_TEMPLATE = """ phi::{}({}{}); """ -# 3. Select Kernel +# 4. Select Kernel KERNEL_SELECTION_TEMPLATE = """ VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( @@ -137,14 +152,18 @@ KERNEL_SELECTION_TEMPLATE = """ auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); """ -# 4. Reshard Input -INPUT_RESHARD_TEMPLATE = """ -""" +# 5. Reshard Input +SINGLE_INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{arg} = ReshardDistTensor(dev_ctx, {arg}, spmd_info.first[{idx}]);""" -# 5. PrepareData +# 6. PrepareData SINGLE_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel); - auto input_{} = &dist_input_{}->value(); + dist_input_{arg} = PrepareDataForDistTensor(dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); + auto input_{arg} = &dist_input_{arg}->value(); +""" +SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ + auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); + auto input_{arg} = &dist_input_{arg}->value(); """ VECTOR_PREPARE_DATA_TEMPLATE = """ auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); @@ -170,7 +189,7 @@ INFER_META_VECTOR_INPUT_TEMPLATE = """ const auto& input_{} = *input_{}_uq_ptr; """ -# 6. Infer Local DenseTensor Meta +# 7. Infer Local DenseTensor Meta SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """ # TODO(GhostScreaming): support optional args later VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """ @@ -189,7 +208,7 @@ INFER_META_TEMPLATE = """ phi::{}({}{}); """ -# 7. DenseTensor Kernel Call +# 8. DenseTensor Kernel Call # TODO(chenweihang): support kernel fallback later SINGLE_OUTPUT_NAME = """dense_out""" # TODO(chenweihang): support vector and tuple output later @@ -205,10 +224,6 @@ KERNEL_CALL_TEMPLATE = """ PREFIX_VECTOR_TENSOR_NAME = "dense_input_" SUFFIX_VECTOR_TENSOR_NAME = "_vec" -# 8. Reshard Output -OUTPUT_RESHARD_TEMPLATE = """ -""" - # BaseAPI members: # inputs: # names : [], list of input names @@ -252,6 +267,17 @@ class DistForwardAPI(ForwardAPI): self.inplace_flag = False self.dist_output_args = [] self.dense_output_args = [] + self.input_args_code = "" + + # override BaseAPI's method + def parse_infer_meta(self, infer_meta_config): + infer_meta = infer_meta_config + if 'param' not in infer_meta_config: + infer_meta['param'] = None + if 'spmd_rule' not in infer_meta_config: + infer_meta['spmd_rule'] = None + + return infer_meta def need_to_generate_code_for_inplace_impl(self, i): return ( @@ -289,6 +315,55 @@ class DistForwardAPI(ForwardAPI): input_args = input_args[:-2] return AUTO_PARALLEL_COND_TEMPLATE.format(input_args) + def generate_infer_spmd_code(self) -> str: + if self.infer_meta['spmd_rule'] is not None: + input_names = self.inputs['names'] + attr_names = self.attrs['names'] + + infer_meta_params = ( + self.infer_meta['param'] + if self.infer_meta['param'] is not None + else input_names + attr_names + ) + input_decl_code = "" + self.input_args_code = "" + for param in infer_meta_params: + if param in input_names: + if self.inputs['input_info'][param] == "const Tensor&": + input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format( + param, param + ) + self.input_args_code += "meta_dist_" + param + ", " + else: + raise ValueError( + f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." + ) + elif param in attr_names: + self.input_args_code = self.input_args_code + param + ", " + elif isinstance(param, str): + self.input_args_code = ( + self.input_args_code + "\"" + param + "\", " + ) + elif isinstance(param, bool): + self.input_args_code = ( + self.input_args_code + str(param).lower() + ", " + ) + else: + self.input_args_code = ( + self.input_args_code + str(param) + ", " + ) + + # TODO(chenweihang): add general spmd rule later + infer_spmd_code = "" + infer_spmd_func_code = self.infer_meta['spmd_rule'] + infer_spmd_code = INFER_SPMD_TEMPLATE.format( + infer_spmd_func_code, self.input_args_code[:-2] + ) + + return input_decl_code + infer_spmd_code + else: + return "" + def generate_output_creation_code(self) -> str: # forward api need to generate api and kernel outputs output_num = len(self.outputs['types']) @@ -311,7 +386,10 @@ class DistForwardAPI(ForwardAPI): self.dist_output_args.append('dist_out') self.dense_output_args.append('dense_out') if self.outputs['types'][0] == 'Tensor': - output_creation_code += SINGLE_OUT_CREATION_TEMPLATE + if self.infer_meta['spmd_rule'] is not None: + output_creation_code += SINGLE_OUT_CREATION_TEMPLATE + else: + output_creation_code += SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD elif self.outputs['types'][0] == 'std::vector': output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format( self.outputs['out_size_expr'][0] @@ -365,11 +443,18 @@ class DistForwardAPI(ForwardAPI): ) ) else: - output_creation_code += ( - MULTI_SINGLE_OUT_CREATION_TEMPLATE.format( - i, get_out_code, i, i + if self.infer_meta['spmd_rule'] is not None: + output_creation_code += ( + MULTI_SINGLE_OUT_CREATION_TEMPLATE.format( + idx=i, out=get_out_code + ) + ) + else: + output_creation_code += ( + MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD.format( + idx=i, out=get_out_code + ) ) - ) else: raise ValueError( "{} : Output error: the output should not be empty.".format( @@ -379,10 +464,9 @@ class DistForwardAPI(ForwardAPI): return output_creation_code - def generate_infer_spmd_code(self) -> str: + def generate_infer_global_shape_code(self) -> str: input_names = self.inputs['names'] attr_names = self.attrs['names'] - output_names = self.outputs['names'] # 1. get infer meta func name infer_meta = self.infer_meta @@ -399,18 +483,18 @@ class DistForwardAPI(ForwardAPI): for param in infer_meta_params: if param in input_names: if self.inputs['input_info'][param] == "const Tensor&": - input_args_code += SINGLE_DIST_META_IN_TEMPLATE.format( + input_args_code += SINGLE_GLOBAL_META_IN_TEMPLATE.format( param ) elif ( self.inputs['input_info'][param] == "const std::vector&" ): - input_args_code += VECTOR_DIST_META_IN_TEMPLATE.format( + input_args_code += VECTOR_GLOBAL_META_IN_TEMPLATE.format( param ) - input_meta_code += VECTOR_DIST_META_IN_DECL_TEMPLATE.format( - name=param + input_meta_code += ( + VECTOR_GLOBAL_META_IN_DECL_TEMPLATE.format(name=param) ) else: raise ValueError( @@ -430,7 +514,7 @@ class DistForwardAPI(ForwardAPI): output_args_code = "" for i, out_name in enumerate(self.dist_output_args): if self.outputs['types'][i] == 'std::vector': - output_decl_code += VECTOR_DIST_META_OUT_DECL_TEMPLATE.format( + output_decl_code += VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE.format( name=out_name ) if len(self.dense_output_args) == 1: @@ -440,7 +524,7 @@ class DistForwardAPI(ForwardAPI): f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, " ) else: - output_decl_code += SINGLE_DIST_META_OUT_DECL_TEMPLATE.format( + output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format( out_name, out_name ) if len(self.dense_output_args) == 1: @@ -451,10 +535,12 @@ class DistForwardAPI(ForwardAPI): ) output_args_code = output_args_code[:-2] + if self.input_args_code != "": + input_args_code = self.input_args_code return ( output_decl_code + input_meta_code - + INFER_SPMD_TEMPLATE.format( + + INFER_GLOBAL_SHAPE_TEMPLATE.format( infer_meta_func_code, input_args_code, output_args_code ) ) @@ -465,7 +551,35 @@ class DistForwardAPI(ForwardAPI): ) def generate_reshard_input_code(self) -> str: - return INPUT_RESHARD_TEMPLATE.format() + input_reshard_code = "" + if self.infer_meta['spmd_rule'] is not None: + input_names = self.inputs['names'] + + infer_meta = self.infer_meta + infer_meta_params = ( + infer_meta['param'] + if infer_meta['param'] is not None + else input_names + ) + for i, param in enumerate(infer_meta_params): + if param in input_names: + if self.inputs['input_info'][param] == "const Tensor&": + input_reshard_code += ( + SINGLE_INPUT_RESHARD_TEMPLATE.format( + arg=param, idx=i + ) + ) + else: + raise ValueError( + f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." + ) + else: + # do nothing + pass + else: + # do nothingd + pass + return input_reshard_code def generate_single_dense_input( self, @@ -479,14 +593,18 @@ class DistForwardAPI(ForwardAPI): if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format( - input_name, - input_name, - kernel_param.index(input_name), - trans_flag, - input_name, - input_name, - ) + if self.infer_meta['spmd_rule'] is not None: + input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format( + arg=input_name, + idx=kernel_param.index(input_name), + flag=trans_flag, + ) + else: + input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format( + arg=input_name, + idx=kernel_param.index(input_name), + flag=trans_flag, + ) return input_tensor_code @@ -706,9 +824,6 @@ class DistForwardAPI(ForwardAPI): ", ".join(self.dense_output_args), ) - def generate_reshard_output_code(self) -> str: - return OUTPUT_RESHARD_TEMPLATE.format() - def generate_return_code(self) -> str: return self.gene_return_code() @@ -718,14 +833,14 @@ class DistForwardAPI(ForwardAPI): return "" return MAIN_DIST_BRANCH_TEMPLATE.format( self.generate_if_condition_code(), - self.generate_output_creation_code(), self.generate_infer_spmd_code(), + self.generate_output_creation_code(), + self.generate_infer_global_shape_code(), self.generate_kernel_selection_code(), self.generate_reshard_input_code(), self.generate_prepare_data_code(), self.generate_infer_meta_code(), self.generate_kernel_call_code(), - self.generate_reshard_output_code(), self.generate_return_code(), ) @@ -777,11 +892,14 @@ class DistForwardAPI(ForwardAPI): # 3. doesn't support view api # 4. only for general forward and backward # 5. only support single tensor input and output + # 6. doesn't support double grad and triple grad dist_branch_code = "" if ( len(self.inputs['names']) > 0 and len(self.view_map) == 0 and self.check_argument_whether_support_auto_parallel() + and not self.api.endswith("_double_grad") + and not self.api.endswith("_triple_grad") ): dist_branch_code = self.generate_auto_paralel_branch() return API_IMPL_TEMPLATE.format( diff --git a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py index 487d6e3a257..25944e33569 100644 --- a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py @@ -22,10 +22,24 @@ from dist_api_gen import DistForwardAPI # Code Gen Templates # ###################### +MAIN_DIST_BRANCH_TEMPLATE = """ + // Auto Parallel condition + if ({}) {{ + // 1. Create API Output & Prepare Dist and Dense Output{} + // 2. Infer DistTensor's Global Shape{} + // 3. Select Kernel{} + // 4. PrepareData (DataTransform & Prepare Dense Input){} + // 5. Infer Local DenseTensor Meta{} + // 6. DenseTensor Kernel Call{} + // 7. Return + {} + }} +""" + # 1. Create API Outputs SINGLE_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput({}); - auto dense_out = const_cast(&dist_out->value()); + auto dense_out = dist_out->unsafe_mutable_value(); """ VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput({name}); @@ -39,7 +53,21 @@ INPLACE_OUT_CREATION_TEMPLATE = """ """ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ auto dist_out_{} = SetKernelDistOutput({}); - auto dense_out_{} = const_cast(&dist_out_{}->value()); + auto dense_out_{} = dist_out_{}->unsafe_mutable_value(); +""" + +# 2. Infer Global Shape +SINGLE_DIST_META_IN_TEMPLATE = """MakeDistMetaTensor(*{}.impl()), """ +SINGLE_DIST_META_OUT_DECL_TEMPLATE = """ + phi::distributed::DistMetaTensor meta_{}({});""" +INFER_GLOBAL_SHAPE_TEMPLATE = """ + phi::{}({}{}); +""" + +# 4. PrepareData (DataTransform & Prepare Dist and Dense Input) +SINGLE_PREPARE_DATA_TEMPLATE = """ + auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); + auto input_{arg} = &dist_input_{}->value(); """ @@ -131,6 +159,21 @@ class DistBackwardAPI(DistForwardAPI, BackwardAPI): def gene_api_declaration(self) -> str: return BackwardAPI.gene_api_declaration(self) + def generate_auto_paralel_branch(self) -> str: + # if no tensor input, do not genetate auto parallel branch + if len(self.inputs['names']) == 0: + return "" + return MAIN_DIST_BRANCH_TEMPLATE.format( + self.generate_if_condition_code(), + self.generate_output_creation_code(), + self.generate_infer_global_shape_code(), + self.generate_kernel_selection_code(), + self.generate_prepare_data_code(), + self.generate_infer_meta_code(), + self.generate_kernel_call_code(), + self.generate_return_code(), + ) + def header_include(): return """ diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index fcf4b2f28bb..4c151374c68 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -651,6 +651,7 @@ output : Tensor infer_meta : func : MatmulInferMeta + spmd_rule : MatmulSpmdInferForward kernel : func : matmul backward : matmul_grad diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index 7af036a9268..bc8b98d81a3 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -66,6 +66,15 @@ class DistTensor final /// \return The DenseTensor value's const reference const DenseTensor& value() const { return value_; } + /// \brief Returns the mutable dense tensor value in dist tensor. + /// \note If DenseTensor value is modified externally, the corresponding + /// relationship between it and the current tensor's global dims and + /// dist attr may be destroyed, which may introduce some subtle bugs, + /// so you need to make sure to consider it thoroughly when using + /// this method. + /// \return The mutable pointer of DenseTensor value + DenseTensor* unsafe_mutable_value() { return &value_; } + /// \brief Returns the global dims of the dist tensor. /// \return The global dims of the dist tensor. const DDim& local_dims() const; diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc index 3b94dc017e5..531727b3ee8 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc @@ -73,11 +73,6 @@ bool SpmdRuleFactory::ContainsSpmdRule(const std::string& kernel_name) const { } int SpmdRuleFactory::InsertSpmdRule(std::string kernel_name, SpmdRule rule) { - PADDLE_ENFORCE_NE( - ContainsSpmdRule(kernel_name), - true, - phi::errors::AlreadyExists( - "`%s` Kernel's Spmd rules has been registered.", kernel_name)); spmd_rule_map_.insert({std::move(kernel_name), std::move(rule)}); return 0; } diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index ad519ff287a..5ec2f212ec6 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -27,12 +27,12 @@ limitations under the License. */ * 2. Since the infer functions of Spmd forward and backward are closely related * and need to be registered together, we manage them together in one file. * - * 3. SPMD rules are much smaller than infermeta function, and we manage files - * in operator units. + * 3. SPMD rules are less than infermeta function, and we manage files by + * operator. * * 4. The previous registration used some compile-time regular matching methods, * which was less flexible, and the registration of SPMD rules here is declare - * directly in the header file + * directly in the header file. */ namespace phi { diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index aeb00a0fc72..eae16e02454 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -88,6 +88,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p) set_tests_properties(test_reshard_r_to_p PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) + py_test_modules(test_semi_auto_parallel_basic MODULES + test_semi_auto_parallel_basic) + set_tests_properties(test_semi_auto_parallel_basic + PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) # End of unittests WITH multi cards and timeout # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout diff --git a/test/auto_parallel/semi_auto_parallel_for_matmul.py b/test/auto_parallel/semi_auto_parallel_for_matmul.py new file mode 100644 index 00000000000..953c734e6dc --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_matmul.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestMatmulApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + self._backend = os.getenv("backend") + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def test_body(self, x_specs, y_specs): + x_shape = [64, 32] + y_shape = [32, 48] + x = paddle.randn(x_shape, self._dtype) + y = paddle.randn(y_shape, self._dtype) + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr) + + dist_out = paddle.matmul(dist_x, dist_y) + + # verify global shape + out_shape = [64, 48] + np.testing.assert_equal(dist_out.shape, out_shape, verbose=True) + + return dist_out + + def test_case1(self): + # case1: mk[0,-1],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[] + dist_out = self.test_body(x_specs=['x', None], y_specs=[None, None]) + # verify local shape and dist attr + np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True) + np.testing.assert_equal( + dist_out.dist_attr.dims_mapping, [0, -1], verbose=True + ) + assert dist_out.dist_attr._is_partial() is False + + def test_case2(self): + # case2: mk[-1, 0],kn[-1,-1] --> mk[-1, 0],kn[0, -1] = nm[-1, -1] partial[0] + dist_out = self.test_body(x_specs=[None, 'x'], y_specs=[None, None]) + # verify local shape + np.testing.assert_equal(dist_out._local_shape, [64, 48], verbose=True) + np.testing.assert_equal( + dist_out.dist_attr.dims_mapping, [-1, -1], verbose=True + ) + assert dist_out.dist_attr._is_partial() is True + assert dist_out.dist_attr._partial_dims() == {0} + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_case1() + self.test_case2() + + +if __name__ == '__main__': + TestMatmulApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py new file mode 100644 index 00000000000..a1ec1b18e9b --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelMatmul(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=120) + self._default_envs = { + "dtype": "float32", + "seeds": str(self._seeds), + } + self._changeable_envs = {"backend": ["cpu", "gpu"]} + + def test_matmul_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_matmul.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() -- GitLab