Unverified commit e9364a38, authored by Chen Weihang, committed by GitHub

[AutoParallel] Generate spmd rule and reshard impl in phi api (#56831)

* add spmd and reshard code gen

* add backward reshard code gen

* test matmul forward success

* polish test impl

* add unsafe mutable value

* polish details and add test

* fix unittest time out

* fix typo

* refactor reshard input generate impl

* resolve conflict with develop

* fix compile error
Parent 3b5bb4ba
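For orientation, a minimal sketch of the auto-parallel branch the updated generator emits for a single-output op such as matmul once a spmd_rule is declared; the names mirror the templates in dist_api_gen.py below, and this is illustrative rather than a verbatim excerpt of the generated code:

// Sketch of generated code, assuming matmul(x, y, transpose_x, transpose_y).
if (AllInputsAreDistTensor(x, y)) {
  // 1. InferSpmd (infer DistAttr of inputs & outputs)
  auto meta_dist_x = MakeDistMetaTensor(*x.impl());
  auto meta_dist_y = MakeDistMetaTensor(*y.impl());
  auto spmd_info = phi::distributed::MatmulSpmdInferForward(
      meta_dist_x, meta_dist_y, transpose_x, transpose_y);
  // 2. Create API output & prepare dist and dense outputs
  Tensor api_output;
  auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]);
  auto dense_out = dist_out->unsafe_mutable_value();
  // 3. Infer the DistTensor's global shape
  phi::MetaTensor meta_dist_out(dist_out);
  phi::MatmulInferMeta(MakeMetaTensor(*x.impl()), MakeMetaTensor(*y.impl()),
                       transpose_x, transpose_y, &meta_dist_out);
  // 4. Select kernel and device context (same as the dense branch)
  // 5. Reshard inputs to the DistAttr chosen by InferSpmd
  auto dist_input_x = ReshardDistTensor(dev_ctx, x, spmd_info.first[0]);
  auto dist_input_y = ReshardDistTensor(dev_ctx, y, spmd_info.first[1]);
  // 6. PrepareData (data transform) and take the local DenseTensors
  dist_input_x = PrepareDataForDistTensor(
      dist_input_x, GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
      {}, kernel_result.is_stride_kernel);
  auto input_x = &dist_input_x->value();
  // 7. Infer local DenseTensor meta, 8. call the dense kernel writing into
  // dense_out, then 9. return api_output.
}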
......@@ -366,7 +366,7 @@ def check_op_config(op_entry, op_name):
'composite',
'support_dygraph_mode',
)
infer_meta_key_set = ('func', 'param')
infer_meta_key_set = ('func', 'param', 'spmd_rule')
kernel_key_set = (
'func',
'param',
......
......@@ -22,6 +22,7 @@ PHI_DECLARE_bool(use_stride_kernel);
#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
namespace paddle {
......@@ -530,13 +531,18 @@ void TransStride(phi::DeviceContext* dev_ctx,
/* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
phi::distributed::DistMetaTensor MakeDistMetaTensor(
const phi::TensorBase& tensor) {
return phi::distributed::DistMetaTensor(tensor);
}
phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
if (out) {
// TODO(chenweihang): now all dist case are nullptr
if (out->impl() == nullptr) {
// TODO(chenweihang): polish code, dist_attr is null now
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
phi::DDim(), phi::distributed::TensorDistAttr());
auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
dist_attr);
out->set_impl(dist_t);
}
return static_cast<phi::distributed::DistTensor*>(out->impl().get());
......
......@@ -18,18 +18,15 @@ limitations under the License. */
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/string_tensor.h"
namespace phi {
namespace distributed {
class DistTensor;
} // namespace distributed
} // namespace phi
namespace paddle {
namespace experimental {
......@@ -139,9 +136,17 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
/* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out);
phi::distributed::DistMetaTensor MakeDistMetaTensor(
const phi::TensorBase& tensor);
phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out,
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out);
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
size_t out_size, std::vector<Tensor>* out);
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -597,15 +599,45 @@ void TransDataBackend(const phi::SelectedRows* tensor,
/* ------------------ for auto parallel ----------------------- */
std::shared_ptr<phi::distributed::DistTensor> ReshardDistTensor(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr) {
auto tensor_in = tensor.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (dist_tensor->dist_attr() != dist_attr) {
VLOG(6) << "Reshard tensor from " << dist_tensor->dist_attr() << " to "
<< dist_attr;
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
return func->Eval(dev_ctx, *dist_tensor, dist_attr);
}
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
return nullptr;
}
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
return PrepareDataForDistTensor(
std::static_pointer_cast<phi::distributed::DistTensor>(input.impl()),
target_args_def,
transform_flag,
is_stride_kernel);
}
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const std::shared_ptr<phi::distributed::DistTensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
if (input) {
phi::distributed::DistTensor* dist_tensor = input.get();
const phi::DenseTensor& dense_tensor = dist_tensor->value();
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace(
......@@ -618,16 +650,18 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) {
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
return input;
}
phi::DenseTensor out = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(chenweihang): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor may be changed, such as a layout
// change (NCHW->NHWC), so the new DistTensor's meta may not be unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
return std::make_shared<phi::distributed::DistTensor>(
out, dist_tensor->dist_attr());
auto dist_out = std::make_shared<phi::distributed::DistTensor>(
dist_tensor->dims(), dist_tensor->dist_attr());
auto* out = dist_out->unsafe_mutable_value();
*out = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
return dist_out;
}
return nullptr;
}
......
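Note the behavior of the two helpers added above: ReshardDistTensor only launches a reshard when the tensor's current dist_attr differs from the one chosen by the spmd rule, otherwise it simply returns the existing DistTensor; and the DistTensor overload of PrepareDataForDistTensor keeps the input's global dims and dist_attr, writing only the transformed local DenseTensor through unsafe_mutable_value(). A hedged usage sketch, with dev_ctx, x, spmd_info, and trans_flag standing in for the names used in the generated code:

// Fast path: if x already has the requested dist_attr, no communication runs
// and the same DistTensor is returned.
auto dist_input_x = ReshardDistTensor(dev_ctx, x, spmd_info.first[0]);
// Data transform on the local shard; global dims/dist_attr are preserved.
dist_input_x = PrepareDataForDistTensor(
    dist_input_x,
    GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
    trans_flag,
    kernel_result.is_stride_kernel);
const phi::DenseTensor* input_x = &dist_input_x->value();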
......@@ -21,8 +21,10 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
class DeviceContext;
namespace distributed {
class DistTensor;
class TensorDistAttr;
} // namespace distributed
} // namespace phi
......@@ -173,13 +175,23 @@ inline bool NeedTransformPlace(const phi::Place& src_place,
/* ------------------ for auto parallel ----------------------- */
// TODO(chenweihang): impl Reshard input and output function
std::shared_ptr<phi::distributed::DistTensor> ReshardDistTensor(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const std::shared_ptr<phi::distributed::DistTensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);
std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(const std::vector<Tensor>& input,
const phi::TensorArgDef& target_args_def,
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......@@ -50,6 +51,8 @@ bool HasAllocation(const phi::TensorBase& t) {
} else if (phi::StringTensor::classof(&t)) {
return phi::StringTensorUtils::GetHolder(
static_cast<const phi::StringTensor&>(t)) != nullptr;
} else if (phi::distributed::DistTensor::classof(&t)) {
return static_cast<const phi::distributed::DistTensor&>(t).defined();
} else {
return false;
}
......
......@@ -96,8 +96,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
// data_promote
DataTypeSet dtype_set{DataType::UNDEFINED};
// TODO(chenweihang): deal with multiple diff input Tensors
// TODO(chenweihang): add global device guard method to set backend
inline void AssignKernelKeySet(const phi::TensorBase& tensor) {
// assign Backend
BackendSet tensor_backend_set = detail::GetTensorBackendSet(tensor);
......
......@@ -379,6 +379,10 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif
PD_DECLARE_bool(conv2d_disable_cudnn);
PD_DECLARE_int32(low_precision_op_list);
"""
......
......@@ -44,14 +44,14 @@ PADDLE_THROW(phi::errors::Unimplemented(
MAIN_DIST_BRANCH_TEMPLATE = """
// Auto Parallel condition
if ({}) {{
// 1. Create API Output & Prepare Dist and Dense Output{}
// 2. InferSPMD (Infer Global Shape and DistAttr of Inputs&Outputs){}
// 3. Select Kernel{}
// 4. Reshard Input{}
// 5. PrepareData (DataTransform & Prepare Dist and Dense Input){}
// 6. Infer Local DenseTensor Meta{}
// 7. DenseTensor Kernel Call{}
// 8. Reshard Output{}
// 1. InferSpmd (Infer DistAttr of Inputs&Outputs){}
// 2. Create API Output & Prepare Dist and Dense Output{}
// 3. Infer DistTensor's Global Shape{}
// 4. Select Kernel{}
// 5. Reshard Input{}\n
// 6. PrepareData (DataTransform & Prepare Dense Input){}
// 7. Infer Local DenseTensor Meta{}
// 8. DenseTensor Kernel Call{}
// 9. Return
{}
}}
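The step reordering in this template is deliberate: InferSpmd now runs before output creation so the generated code can pass spmd_info.second[i] (the output DistAttr inferred by the rule) straight into SetKernelDistOutput, input resharding is driven by spmd_info.first, and the old standalone "Reshard Output" step is dropped.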
......@@ -60,20 +60,35 @@ MAIN_DIST_BRANCH_TEMPLATE = """
# Auto Parallel condition
AUTO_PARALLEL_COND_TEMPLATE = """AllInputsAreDistTensor({})"""
# 1. Create API Outputs
# 1. InferSPMD
SINGLE_DIST_META_IN_TEMPLATE = """
auto meta_dist_{} = MakeDistMetaTensor(*{}.impl());"""
INFER_SPMD_TEMPLATE = """
auto spmd_info = phi::distributed::{}({});
"""
# 2. Create API Outputs
API_OUT_CREATION_TEMPLATE = """
{} api_output{};
"""
INPLACE_API_OUT_CREATION_TEMPLATE = """
{} api_output{{{}}};
"""
SINGLE_OUT_CREATION_TEMPLATE = """
SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """
auto dist_out = SetKernelDistOutput(&api_output);
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());
auto dense_out = dist_out->unsafe_mutable_value();
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """
auto dist_out_{idx} = SetKernelDistOutput({out});
auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value();
"""
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]);
auto dense_out = dist_out->unsafe_mutable_value();
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
auto dist_out_{idx} = SetKernelDistOutput({out}, spmd_info.second[{idx}]);
auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value();
"""
VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({}, &api_output);
......@@ -93,12 +108,12 @@ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
TUPLE_OUT_CREATION_TEMPLATE = """
"""
# 2. InferSPMD
# Call InferMeta now, replace by InferSPMD function later
# TODO(chenweihang): InferSPMD function design
SINGLE_DIST_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """
VECTOR_DIST_META_IN_TEMPLATE = """{}_meta_ptr_vec, """
VECTOR_DIST_META_IN_DECL_TEMPLATE = """
# 3. Infer Global Shape
# TODO(chenweihang): the input MetaTensor created by Inferspmd can be reused
# for InferGlobalShape to avoid creating repeated inputs.
SINGLE_GLOBAL_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """
VECTOR_GLOBAL_META_IN_TEMPLATE = """{}_meta_ptr_vec, """
VECTOR_GLOBAL_META_IN_DECL_TEMPLATE = """
std::vector<phi::MetaTensor> {name}_meta_vec;
for (auto tmp : {name}) {{
{name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl()));
......@@ -109,11 +124,11 @@ VECTOR_DIST_META_IN_DECL_TEMPLATE = """
}}
"""
# TODO(GhostScreaming): support optional args later
OPTIONAL_DIST_VECTOR_META_IN_TEMPLATE = """
OPTIONAL_GLOBAL_VECTOR_META_IN_TEMPLATE = """
"""
SINGLE_DIST_META_OUT_DECL_TEMPLATE = """
SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE = """
phi::MetaTensor meta_{}({});"""
VECTOR_DIST_META_OUT_DECL_TEMPLATE = """
VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE = """
std::vector<phi::MetaTensor> {name}_meta_vec;
for (auto tmp : {name}) {{
{name}_meta_vec.emplace_back(phi::MetaTensor(tmp));
......@@ -123,11 +138,11 @@ VECTOR_DIST_META_OUT_DECL_TEMPLATE = """
{name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
}}
"""
INFER_SPMD_TEMPLATE = """
INFER_GLOBAL_SHAPE_TEMPLATE = """
phi::{}({}{});
"""
# 3. Select Kernel
# 4. Select Kernel
KERNEL_SELECTION_TEMPLATE = """
VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
......@@ -137,14 +152,18 @@ KERNEL_SELECTION_TEMPLATE = """
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
"""
# 4. Reshard Input
INPUT_RESHARD_TEMPLATE = """
"""
# 5. Reshard Input
SINGLE_INPUT_RESHARD_TEMPLATE = """
auto dist_input_{arg} = ReshardDistTensor(dev_ctx, {arg}, spmd_info.first[{idx}]);"""
# 5. PrepareData
# 6. PrepareData
SINGLE_PREPARE_DATA_TEMPLATE = """
auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel);
auto input_{} = &dist_input_{}->value();
dist_input_{arg} = PrepareDataForDistTensor(dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
auto input_{arg} = &dist_input_{arg}->value();
"""
SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """
auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
auto input_{arg} = &dist_input_{arg}->value();
"""
VECTOR_PREPARE_DATA_TEMPLATE = """
auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
......@@ -170,7 +189,7 @@ INFER_META_VECTOR_INPUT_TEMPLATE = """
const auto& input_{} = *input_{}_uq_ptr;
"""
# 6. Infer Local DenseTensor Meta
# 7. Infer Local DenseTensor Meta
SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """
# TODO(GhostScreaming): support optional args later
VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """
......@@ -189,7 +208,7 @@ INFER_META_TEMPLATE = """
phi::{}({}{});
"""
# 7. DenseTensor Kernel Call
# 8. DenseTensor Kernel Call
# TODO(chenweihang): support kernel fallback later
SINGLE_OUTPUT_NAME = """dense_out"""
# TODO(chenweihang): support vector and tuple output later
......@@ -205,10 +224,6 @@ KERNEL_CALL_TEMPLATE = """
PREFIX_VECTOR_TENSOR_NAME = "dense_input_"
SUFFIX_VECTOR_TENSOR_NAME = "_vec"
# 8. Reshard Output
OUTPUT_RESHARD_TEMPLATE = """
"""
# BaseAPI members:
# inputs:
# names : [], list of input names
......@@ -252,6 +267,17 @@ class DistForwardAPI(ForwardAPI):
self.inplace_flag = False
self.dist_output_args = []
self.dense_output_args = []
self.input_args_code = ""
# override BaseAPI's method
def parse_infer_meta(self, infer_meta_config):
infer_meta = infer_meta_config
if 'param' not in infer_meta_config:
infer_meta['param'] = None
if 'spmd_rule' not in infer_meta_config:
infer_meta['spmd_rule'] = None
return infer_meta
def need_to_generate_code_for_inplace_impl(self, i):
return (
......@@ -289,6 +315,55 @@ class DistForwardAPI(ForwardAPI):
input_args = input_args[:-2]
return AUTO_PARALLEL_COND_TEMPLATE.format(input_args)
def generate_infer_spmd_code(self) -> str:
if self.infer_meta['spmd_rule'] is not None:
input_names = self.inputs['names']
attr_names = self.attrs['names']
infer_meta_params = (
self.infer_meta['param']
if self.infer_meta['param'] is not None
else input_names + attr_names
)
input_decl_code = ""
self.input_args_code = ""
for param in infer_meta_params:
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format(
param, param
)
self.input_args_code += "meta_dist_" + param + ", "
else:
raise ValueError(
f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
self.input_args_code = self.input_args_code + param + ", "
elif isinstance(param, str):
self.input_args_code = (
self.input_args_code + "\"" + param + "\", "
)
elif isinstance(param, bool):
self.input_args_code = (
self.input_args_code + str(param).lower() + ", "
)
else:
self.input_args_code = (
self.input_args_code + str(param) + ", "
)
# TODO(chenweihang): add general spmd rule later
infer_spmd_code = ""
infer_spmd_func_code = self.infer_meta['spmd_rule']
infer_spmd_code = INFER_SPMD_TEMPLATE.format(
infer_spmd_func_code, self.input_args_code[:-2]
)
return input_decl_code + infer_spmd_code
else:
return ""
def generate_output_creation_code(self) -> str:
# forward api need to generate api and kernel outputs
output_num = len(self.outputs['types'])
......@@ -311,7 +386,10 @@ class DistForwardAPI(ForwardAPI):
self.dist_output_args.append('dist_out')
self.dense_output_args.append('dense_out')
if self.outputs['types'][0] == 'Tensor':
if self.infer_meta['spmd_rule'] is not None:
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE
else:
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD
elif self.outputs['types'][0] == 'std::vector<Tensor>':
output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
self.outputs['out_size_expr'][0]
......@@ -365,9 +443,16 @@ class DistForwardAPI(ForwardAPI):
)
)
else:
if self.infer_meta['spmd_rule'] is not None:
output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE.format(
i, get_out_code, i, i
idx=i, out=get_out_code
)
)
else:
output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD.format(
idx=i, out=get_out_code
)
)
else:
......@@ -379,10 +464,9 @@ class DistForwardAPI(ForwardAPI):
return output_creation_code
def generate_infer_spmd_code(self) -> str:
def generate_infer_global_shape_code(self) -> str:
input_names = self.inputs['names']
attr_names = self.attrs['names']
output_names = self.outputs['names']
# 1. get infer meta func name
infer_meta = self.infer_meta
......@@ -399,18 +483,18 @@ class DistForwardAPI(ForwardAPI):
for param in infer_meta_params:
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_args_code += SINGLE_DIST_META_IN_TEMPLATE.format(
input_args_code += SINGLE_GLOBAL_META_IN_TEMPLATE.format(
param
)
elif (
self.inputs['input_info'][param]
== "const std::vector<Tensor>&"
):
input_args_code += VECTOR_DIST_META_IN_TEMPLATE.format(
input_args_code += VECTOR_GLOBAL_META_IN_TEMPLATE.format(
param
)
input_meta_code += VECTOR_DIST_META_IN_DECL_TEMPLATE.format(
name=param
input_meta_code += (
VECTOR_GLOBAL_META_IN_DECL_TEMPLATE.format(name=param)
)
else:
raise ValueError(
......@@ -430,7 +514,7 @@ class DistForwardAPI(ForwardAPI):
output_args_code = ""
for i, out_name in enumerate(self.dist_output_args):
if self.outputs['types'][i] == 'std::vector<Tensor>':
output_decl_code += VECTOR_DIST_META_OUT_DECL_TEMPLATE.format(
output_decl_code += VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE.format(
name=out_name
)
if len(self.dense_output_args) == 1:
......@@ -440,7 +524,7 @@ class DistForwardAPI(ForwardAPI):
f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, "
)
else:
output_decl_code += SINGLE_DIST_META_OUT_DECL_TEMPLATE.format(
output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format(
out_name, out_name
)
if len(self.dense_output_args) == 1:
......@@ -451,10 +535,12 @@ class DistForwardAPI(ForwardAPI):
)
output_args_code = output_args_code[:-2]
if self.input_args_code != "":
input_args_code = self.input_args_code
return (
output_decl_code
+ input_meta_code
+ INFER_SPMD_TEMPLATE.format(
+ INFER_GLOBAL_SHAPE_TEMPLATE.format(
infer_meta_func_code, input_args_code, output_args_code
)
)
......@@ -465,7 +551,35 @@ class DistForwardAPI(ForwardAPI):
)
def generate_reshard_input_code(self) -> str:
return INPUT_RESHARD_TEMPLATE.format()
input_reshard_code = ""
if self.infer_meta['spmd_rule'] is not None:
input_names = self.inputs['names']
infer_meta = self.infer_meta
infer_meta_params = (
infer_meta['param']
if infer_meta['param'] is not None
else input_names
)
for i, param in enumerate(infer_meta_params):
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_reshard_code += (
SINGLE_INPUT_RESHARD_TEMPLATE.format(
arg=param, idx=i
)
)
else:
raise ValueError(
f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported."
)
else:
# do nothing
pass
else:
# do nothing
pass
return input_reshard_code
def generate_single_dense_input(
self,
......@@ -479,13 +593,17 @@ class DistForwardAPI(ForwardAPI):
if kernel_param is None:
kernel_param = input_names + attr_names
if self.infer_meta['spmd_rule'] is not None:
input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format(
input_name,
input_name,
kernel_param.index(input_name),
trans_flag,
input_name,
input_name,
arg=input_name,
idx=kernel_param.index(input_name),
flag=trans_flag,
)
else:
input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format(
arg=input_name,
idx=kernel_param.index(input_name),
flag=trans_flag,
)
return input_tensor_code
......@@ -706,9 +824,6 @@ class DistForwardAPI(ForwardAPI):
", ".join(self.dense_output_args),
)
def generate_reshard_output_code(self) -> str:
return OUTPUT_RESHARD_TEMPLATE.format()
def generate_return_code(self) -> str:
return self.gene_return_code()
......@@ -718,14 +833,14 @@ class DistForwardAPI(ForwardAPI):
return ""
return MAIN_DIST_BRANCH_TEMPLATE.format(
self.generate_if_condition_code(),
self.generate_output_creation_code(),
self.generate_infer_spmd_code(),
self.generate_output_creation_code(),
self.generate_infer_global_shape_code(),
self.generate_kernel_selection_code(),
self.generate_reshard_input_code(),
self.generate_prepare_data_code(),
self.generate_infer_meta_code(),
self.generate_kernel_call_code(),
self.generate_reshard_output_code(),
self.generate_return_code(),
)
......@@ -777,11 +892,14 @@ class DistForwardAPI(ForwardAPI):
# 3. doesn't support view api
# 4. only for general forward and backward
# 5. only support single tensor input and output
# 6. doesn't support double grad and triple grad
dist_branch_code = ""
if (
len(self.inputs['names']) > 0
and len(self.view_map) == 0
and self.check_argument_whether_support_auto_parallel()
and not self.api.endswith("_double_grad")
and not self.api.endswith("_triple_grad")
):
dist_branch_code = self.generate_auto_paralel_branch()
return API_IMPL_TEMPLATE.format(
......
......@@ -22,10 +22,24 @@ from dist_api_gen import DistForwardAPI
# Code Gen Templates #
######################
MAIN_DIST_BRANCH_TEMPLATE = """
// Auto Parallel condition
if ({}) {{
// 1. Create API Output & Prepare Dist and Dense Output{}
// 2. Infer DistTensor's Global Shape{}
// 3. Select Kernel{}
// 4. PrepareData (DataTransform & Prepare Dense Input){}
// 5. Infer Local DenseTensor Meta{}
// 6. DenseTensor Kernel Call{}
// 7. Return
{}
}}
"""
# 1. Create API Outputs
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({});
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());
auto dense_out = dist_out->unsafe_mutable_value();
"""
VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({name});
......@@ -39,7 +53,21 @@ INPLACE_OUT_CREATION_TEMPLATE = """
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
auto dense_out_{} = dist_out_{}->unsafe_mutable_value();
"""
# 2. Infer Global Shape
SINGLE_DIST_META_IN_TEMPLATE = """MakeDistMetaTensor(*{}.impl()), """
SINGLE_DIST_META_OUT_DECL_TEMPLATE = """
phi::distributed::DistMetaTensor meta_{}({});"""
INFER_GLOBAL_SHAPE_TEMPLATE = """
phi::{}({}{});
"""
# 4. PrepareData (DataTransform & Prepare Dist and Dense Input)
SINGLE_PREPARE_DATA_TEMPLATE = """
auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
auto input_{arg} = &dist_input_{arg}->value();
"""
......@@ -131,6 +159,21 @@ class DistBackwardAPI(DistForwardAPI, BackwardAPI):
def gene_api_declaration(self) -> str:
return BackwardAPI.gene_api_declaration(self)
def generate_auto_paralel_branch(self) -> str:
# if no tensor input, do not generate auto parallel branch
if len(self.inputs['names']) == 0:
return ""
return MAIN_DIST_BRANCH_TEMPLATE.format(
self.generate_if_condition_code(),
self.generate_output_creation_code(),
self.generate_infer_global_shape_code(),
self.generate_kernel_selection_code(),
self.generate_prepare_data_code(),
self.generate_infer_meta_code(),
self.generate_kernel_call_code(),
self.generate_return_code(),
)
def header_include():
return """
......
......@@ -651,6 +651,7 @@
output : Tensor
infer_meta :
func : MatmulInferMeta
spmd_rule : MatmulSpmdInferForward
kernel :
func : matmul
backward : matmul_grad
......
......@@ -66,6 +66,15 @@ class DistTensor final
/// \return The DenseTensor value's const reference
const DenseTensor& value() const { return value_; }
/// \brief Returns the mutable dense tensor value in dist tensor.
/// \note If the DenseTensor value is modified externally, the consistency
/// between it and the current tensor's global dims and dist attr may be
/// broken, which can introduce subtle bugs, so consider this carefully
/// before using this method.
/// \return The mutable pointer of DenseTensor value
DenseTensor* unsafe_mutable_value() { return &value_; }
/// \brief Returns the global dims of the dist tensor.
/// \return The global dims of the dist tensor.
const DDim& local_dims() const;
......
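The generated API code relies on this accessor to give the dense kernel a writable view of the local shard while the DistTensor keeps the global dims and dist attr chosen by InferSpmd; the caveat in the note is exactly the situation flagged in PrepareDataForDistTensor, where a layout transform can leave the local meta out of step with the global meta. A minimal sketch of the intended pattern (names as in the generated API code):

// dist_out owns the global shape and the output DistAttr from the spmd rule;
// the kernel writes only into the local DenseTensor shard.
auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]);
phi::DenseTensor* dense_out = dist_out->unsafe_mutable_value();
// The dense kernel fills *dense_out; avoid changing meta that must stay
// consistent with the DistTensor's global dims and dist_attr.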
......@@ -73,11 +73,6 @@ bool SpmdRuleFactory::ContainsSpmdRule(const std::string& kernel_name) const {
}
int SpmdRuleFactory::InsertSpmdRule(std::string kernel_name, SpmdRule rule) {
PADDLE_ENFORCE_NE(
ContainsSpmdRule(kernel_name),
true,
phi::errors::AlreadyExists(
"`%s` Kernel's Spmd rules has been registered.", kernel_name));
spmd_rule_map_.insert({std::move(kernel_name), std::move(rule)});
return 0;
}
......
......@@ -27,12 +27,12 @@ limitations under the License. */
* 2. Since the infer functions of Spmd forward and backward are closely related
* and need to be registered together, we manage them together in one file.
*
* 3. SPMD rules are much smaller than infermeta function, and we manage files
* in operator units.
* 3. There are fewer SPMD rules than InferMeta functions, and we manage the
* files by operator.
*
* 4. The previous registration used some compile-time regular matching methods,
* which was less flexible, and the registration of SPMD rules here is declared
* directly in the header file
* directly in the header file.
*/
namespace phi {
......
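For context, the header-file registration this comment describes pairs a forward and a backward rule per kernel. A sketch of what such an entry could look like, assuming the PD_REGISTER_SPMD_RULE / PD_INFER_SPMD macros and the MatmulSpmdInferBackward name, none of which appear in this diff:

// Hypothetical registration entry in rules.h (illustrative only).
PD_REGISTER_SPMD_RULE(
    matmul,
    PD_INFER_SPMD(phi::distributed::MatmulSpmdInferForward),
    PD_INFER_SPMD(phi::distributed::MatmulSpmdInferBackward));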
......@@ -88,6 +88,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
set_tests_properties(test_reshard_r_to_p
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
py_test_modules(test_semi_auto_parallel_basic MODULES
test_semi_auto_parallel_basic)
set_tests_properties(test_semi_auto_parallel_basic
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
class TestMatmulApiForSemiAutoParallel:
def __init__(self):
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
def test_body(self, x_specs, y_specs):
x_shape = [64, 32]
y_shape = [32, 48]
x = paddle.randn(x_shape, self._dtype)
y = paddle.randn(y_shape, self._dtype)
x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs)
y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs)
dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr)
dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr)
dist_out = paddle.matmul(dist_x, dist_y)
# verify global shape
out_shape = [64, 48]
np.testing.assert_equal(dist_out.shape, out_shape, verbose=True)
return dist_out
def test_case1(self):
# case1: mk[0,-1],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[]
dist_out = self.test_body(x_specs=['x', None], y_specs=[None, None])
# verify local shape and dist attr
np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True)
np.testing.assert_equal(
dist_out.dist_attr.dims_mapping, [0, -1], verbose=True
)
assert dist_out.dist_attr._is_partial() is False
def test_case2(self):
# case2: mk[-1, 0],kn[-1,-1] --> mk[-1, 0],kn[0, -1] = mn[-1, -1] partial[0]
dist_out = self.test_body(x_specs=[None, 'x'], y_specs=[None, None])
# verify local shape
np.testing.assert_equal(dist_out._local_shape, [64, 48], verbose=True)
np.testing.assert_equal(
dist_out.dist_attr.dims_mapping, [-1, -1], verbose=True
)
assert dist_out.dist_attr._is_partial() is True
assert dist_out.dist_attr._partial_dims() == {0}
def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
elif self._backend == "gpu":
paddle.set_device("gpu:" + str(dist.get_rank()))
else:
raise ValueError("Only support cpu or gpu backend.")
self.test_case1()
self.test_case2()
if __name__ == '__main__':
TestMatmulApiForSemiAutoParallel().run_test_case()
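A note on the expected values above: the mesh has two ranks, so in case 1 x ([64, 32], sharded along m) leaves each rank a [32, 32] shard of x and a full copy of y, hence the local output shard is [32, 48] with dims_mapping [0, -1]; in case 2 x is sharded along k, so each rank computes a full-size [64, 48] partial product that still has to be summed over mesh dim 0, hence dims_mapping [-1, -1] with partial dim {0}.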
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import collective.test_communication_api_base as test_base
class TestSemiAutoParallelMatmul(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"dtype": "float32",
"seeds": str(self._seeds),
}
self._changeable_envs = {"backend": ["cpu", "gpu"]}
def test_matmul_api(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_for_matmul.py",
user_defined_envs=envs,
)
if __name__ == "__main__":
unittest.main()