未验证 提交 e9364a38 编写于 作者: C Chen Weihang 提交者: GitHub

[AutoParallel] Generate spmd rule and reshard impl in phi api (#56831)

* add spmd and reshard code gen

* add backward reshard code gen

* test matmul forward success

* polish test impl

* add unsafe mutable value

* polish details and add test

* fix unittest time out

* fix typo

* refactor reshard input generate impl

* resolve conflict with develop

* fix compile error
上级 3b5bb4ba
...@@ -366,7 +366,7 @@ def check_op_config(op_entry, op_name): ...@@ -366,7 +366,7 @@ def check_op_config(op_entry, op_name):
'composite', 'composite',
'support_dygraph_mode', 'support_dygraph_mode',
) )
infer_meta_key_set = ('func', 'param') infer_meta_key_set = ('func', 'param', 'spmd_rule')
kernel_key_set = ( kernel_key_set = (
'func', 'func',
'param', 'param',
......
...@@ -22,6 +22,7 @@ PHI_DECLARE_bool(use_stride_kernel); ...@@ -22,6 +22,7 @@ PHI_DECLARE_bool(use_stride_kernel);
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
namespace paddle { namespace paddle {
...@@ -530,13 +531,18 @@ void TransStride(phi::DeviceContext* dev_ctx, ...@@ -530,13 +531,18 @@ void TransStride(phi::DeviceContext* dev_ctx,
/* ------------------ for auto parallel ----------------------- */ /* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) { phi::distributed::DistMetaTensor MakeDistMetaTensor(
const phi::TensorBase& tensor) {
return phi::distributed::DistMetaTensor(tensor);
}
phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
if (out) { if (out) {
// TODO(chenweihang): now all dist case are nullptr // TODO(chenweihang): now all dist case are nullptr
if (out->impl() == nullptr) { if (out->impl() == nullptr) {
// TODO(chenweihang): polish code, dist_attr is null now auto dist_t = std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
auto dist_t = std::make_shared<phi::distributed::DistTensor>( dist_attr);
phi::DDim(), phi::distributed::TensorDistAttr());
out->set_impl(dist_t); out->set_impl(dist_t);
} }
return static_cast<phi::distributed::DistTensor*>(out->impl().get()); return static_cast<phi::distributed::DistTensor*>(out->impl().get());
......
...@@ -18,18 +18,15 @@ limitations under the License. */ ...@@ -18,18 +18,15 @@ limitations under the License. */
#include "paddle/phi/backends/all_context.h" #include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/string_tensor.h" #include "paddle/phi/core/string_tensor.h"
namespace phi {
namespace distributed {
class DistTensor;
} // namespace distributed
} // namespace phi
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
...@@ -139,9 +136,17 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, ...@@ -139,9 +136,17 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
/* ------------------ for auto parallel ----------------------- */ /* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out); phi::distributed::DistMetaTensor MakeDistMetaTensor(
const phi::TensorBase& tensor);
phi::distributed::DistTensor* SetKernelDistOutput(
Tensor* out,
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput( std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out); std::vector<Tensor*> out);
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput( std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
size_t out_size, std::vector<Tensor>* out); size_t out_size, std::vector<Tensor>* out);
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/context_pool.h" #include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/flags.h" #include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
...@@ -597,15 +599,45 @@ void TransDataBackend(const phi::SelectedRows* tensor, ...@@ -597,15 +599,45 @@ void TransDataBackend(const phi::SelectedRows* tensor,
/* ------------------ for auto parallel ----------------------- */ /* ------------------ for auto parallel ----------------------- */
std::shared_ptr<phi::distributed::DistTensor> ReshardDistTensor(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr) {
auto tensor_in = tensor.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (dist_tensor->dist_attr() != dist_attr) {
VLOG(6) << "Reshard tensor from " << dist_tensor->dist_attr() << " to "
<< dist_attr;
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
return func->Eval(dev_ctx, *dist_tensor, dist_attr);
}
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
return nullptr;
}
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor( std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input, const Tensor& input,
const phi::TensorArgDef& target_args_def, const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag, const TransformFlag& transform_flag,
bool is_stride_kernel) { bool is_stride_kernel) {
const auto& tensor_in = input.impl(); return PrepareDataForDistTensor(
if (tensor_in) { std::static_pointer_cast<phi::distributed::DistTensor>(input.impl()),
phi::distributed::DistTensor* dist_tensor = target_args_def,
static_cast<phi::distributed::DistTensor*>(tensor_in.get()); transform_flag,
is_stride_kernel);
}
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const std::shared_ptr<phi::distributed::DistTensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
if (input) {
phi::distributed::DistTensor* dist_tensor = input.get();
const phi::DenseTensor& dense_tensor = dist_tensor->value(); const phi::DenseTensor& dense_tensor = dist_tensor->value();
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() || if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace( (!NeedTransformPlace(
...@@ -618,16 +650,18 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor( ...@@ -618,16 +650,18 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
transform_flag) && transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel, !NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) { dense_tensor.meta().is_contiguous()))) {
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in); return input;
} }
phi::DenseTensor out = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(chenweihang): The global meta in DistTensor is not changed, // TODO(chenweihang): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout // but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified. // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor"; VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
return std::make_shared<phi::distributed::DistTensor>( auto dist_out = std::make_shared<phi::distributed::DistTensor>(
out, dist_tensor->dist_attr()); dist_tensor->dims(), dist_tensor->dist_attr());
auto* out = dist_out->unsafe_mutable_value();
*out = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
return dist_out;
} }
return nullptr; return nullptr;
} }
......
...@@ -21,8 +21,10 @@ limitations under the License. */ ...@@ -21,8 +21,10 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi { namespace phi {
class DeviceContext;
namespace distributed { namespace distributed {
class DistTensor; class DistTensor;
class TensorDistAttr;
} // namespace distributed } // namespace distributed
} // namespace phi } // namespace phi
...@@ -173,13 +175,23 @@ inline bool NeedTransformPlace(const phi::Place& src_place, ...@@ -173,13 +175,23 @@ inline bool NeedTransformPlace(const phi::Place& src_place,
/* ------------------ for auto parallel ----------------------- */ /* ------------------ for auto parallel ----------------------- */
// TODO(chenweihang): impl Reshard input and output function std::shared_ptr<phi::distributed::DistTensor> ReshardDistTensor(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor( std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input, const Tensor& input,
const phi::TensorArgDef& target_args_def, const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag, const TransformFlag& transform_flag,
bool is_stride_kernel); bool is_stride_kernel);
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const std::shared_ptr<phi::distributed::DistTensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);
std::vector<std::shared_ptr<phi::distributed::DistTensor>> std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(const std::vector<Tensor>& input, PrepareDataForDistTensor(const std::vector<Tensor>& input,
const phi::TensorArgDef& target_args_def, const phi::TensorArgDef& target_args_def,
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h" #include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
...@@ -50,6 +51,8 @@ bool HasAllocation(const phi::TensorBase& t) { ...@@ -50,6 +51,8 @@ bool HasAllocation(const phi::TensorBase& t) {
} else if (phi::StringTensor::classof(&t)) { } else if (phi::StringTensor::classof(&t)) {
return phi::StringTensorUtils::GetHolder( return phi::StringTensorUtils::GetHolder(
static_cast<const phi::StringTensor&>(t)) != nullptr; static_cast<const phi::StringTensor&>(t)) != nullptr;
} else if (phi::distributed::DistTensor::classof(&t)) {
return static_cast<const phi::distributed::DistTensor&>(t).defined();
} else { } else {
return false; return false;
} }
......
...@@ -96,8 +96,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -96,8 +96,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
// data_promote // data_promote
DataTypeSet dtype_set{DataType::UNDEFINED}; DataTypeSet dtype_set{DataType::UNDEFINED};
// TODO(chenweihang): deal with multiple diff input Tensors
// TODO(chenweihang): add global device guard method to set backend
inline void AssignKernelKeySet(const phi::TensorBase& tensor) { inline void AssignKernelKeySet(const phi::TensorBase& tensor) {
// assign Backend // assign Backend
BackendSet tensor_backend_set = detail::GetTensorBackendSet(tensor); BackendSet tensor_backend_set = detail::GetTensorBackendSet(tensor);
......
...@@ -379,6 +379,10 @@ def source_include(header_file_path): ...@@ -379,6 +379,10 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h" #include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h" #include "paddle/phi/api/profiler/supplement_tracing.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif
PD_DECLARE_bool(conv2d_disable_cudnn); PD_DECLARE_bool(conv2d_disable_cudnn);
PD_DECLARE_int32(low_precision_op_list); PD_DECLARE_int32(low_precision_op_list);
""" """
......
...@@ -44,14 +44,14 @@ PADDLE_THROW(phi::errors::Unimplemented( ...@@ -44,14 +44,14 @@ PADDLE_THROW(phi::errors::Unimplemented(
MAIN_DIST_BRANCH_TEMPLATE = """ MAIN_DIST_BRANCH_TEMPLATE = """
// Auto Parallel condition // Auto Parallel condition
if ({}) {{ if ({}) {{
// 1. Create API Output & Prepare Dist and Dense Output{} // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){}
// 2. InferSPMD (Infer Global Shape and DistAttr of Inputs&Outputs){} // 2. Create API Output & Prepare Dist and Dense Output{}
// 3. Select Kernel{} // 3. Infer DistTensor's Global Shape{}
// 4. Reshard Input{} // 4. Select Kernel{}
// 5. PrepareData (DataTransform & Prepare Dist and Dense Input){} // 5. Reshard Input{}\n
// 6. Infer Local DenseTensor Meta{} // 6. PrepareData (DataTransform & Prepare Dense Input){}
// 7. DenseTensor Kernel Call{} // 7. Infer Local DenseTensor Meta{}
// 8. Reshard Output{} // 8. DenseTensor Kernel Call{}
// 9. Return // 9. Return
{} {}
}} }}
...@@ -60,20 +60,35 @@ MAIN_DIST_BRANCH_TEMPLATE = """ ...@@ -60,20 +60,35 @@ MAIN_DIST_BRANCH_TEMPLATE = """
# Auto Parallel condition # Auto Parallel condition
AUTO_PARALLEL_COND_TEMPLATE = """AllInputsAreDistTensor({})""" AUTO_PARALLEL_COND_TEMPLATE = """AllInputsAreDistTensor({})"""
# 1. Create API Outputs # 1. InferSPMD
SINGLE_DIST_META_IN_TEMPLATE = """
auto meta_dist_{} = MakeDistMetaTensor(*{}.impl());"""
INFER_SPMD_TEMPLATE = """
auto spmd_info = phi::distributed::{}({});
"""
# 2. Create API Outputs
API_OUT_CREATION_TEMPLATE = """ API_OUT_CREATION_TEMPLATE = """
{} api_output{}; {} api_output{};
""" """
INPLACE_API_OUT_CREATION_TEMPLATE = """ INPLACE_API_OUT_CREATION_TEMPLATE = """
{} api_output{{{}}}; {} api_output{{{}}};
""" """
SINGLE_OUT_CREATION_TEMPLATE = """ SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """
auto dist_out = SetKernelDistOutput(&api_output); auto dist_out = SetKernelDistOutput(&api_output);
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value()); auto dense_out = dist_out->unsafe_mutable_value();
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD = """
auto dist_out_{idx} = SetKernelDistOutput({out});
auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value();
"""
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]);
auto dense_out = dist_out->unsafe_mutable_value();
""" """
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({}); auto dist_out_{idx} = SetKernelDistOutput({out}, spmd_info.second[{idx}]);
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value()); auto dense_out_{idx} = dist_out_{idx}->unsafe_mutable_value();
""" """
VECTOR_OUT_CREATION_TEMPLATE = """ VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({}, &api_output); auto dist_out = SetKernelDistOutput({}, &api_output);
...@@ -93,12 +108,12 @@ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """ ...@@ -93,12 +108,12 @@ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
TUPLE_OUT_CREATION_TEMPLATE = """ TUPLE_OUT_CREATION_TEMPLATE = """
""" """
# 2. InferSPMD # 3. Infer Global Shape
# Call InferMeta now, replace by InferSPMD function later # TODO(chenweihang): the input MetaTensor created by Inferspmd can be reused
# TODO(chenweihang): InferSPMD function design # for InferGlobalShape to avoid creating repeated inputs.
SINGLE_DIST_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """ SINGLE_GLOBAL_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """
VECTOR_DIST_META_IN_TEMPLATE = """{}_meta_ptr_vec, """ VECTOR_GLOBAL_META_IN_TEMPLATE = """{}_meta_ptr_vec, """
VECTOR_DIST_META_IN_DECL_TEMPLATE = """ VECTOR_GLOBAL_META_IN_DECL_TEMPLATE = """
std::vector<phi::MetaTensor> {name}_meta_vec; std::vector<phi::MetaTensor> {name}_meta_vec;
for (auto tmp : {name}) {{ for (auto tmp : {name}) {{
{name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl())); {name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl()));
...@@ -109,11 +124,11 @@ VECTOR_DIST_META_IN_DECL_TEMPLATE = """ ...@@ -109,11 +124,11 @@ VECTOR_DIST_META_IN_DECL_TEMPLATE = """
}} }}
""" """
# TODO(GhostScreaming): support optional args later # TODO(GhostScreaming): support optional args later
OPTIONAL_DIST_VECTOR_META_IN_TEMPLATE = """ OPTIONAL_GLOBAL_VECTOR_META_IN_TEMPLATE = """
""" """
SINGLE_DIST_META_OUT_DECL_TEMPLATE = """ SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE = """
phi::MetaTensor meta_{}({});""" phi::MetaTensor meta_{}({});"""
VECTOR_DIST_META_OUT_DECL_TEMPLATE = """ VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE = """
std::vector<phi::MetaTensor> {name}_meta_vec; std::vector<phi::MetaTensor> {name}_meta_vec;
for (auto tmp : {name}) {{ for (auto tmp : {name}) {{
{name}_meta_vec.emplace_back(phi::MetaTensor(tmp)); {name}_meta_vec.emplace_back(phi::MetaTensor(tmp));
...@@ -123,11 +138,11 @@ VECTOR_DIST_META_OUT_DECL_TEMPLATE = """ ...@@ -123,11 +138,11 @@ VECTOR_DIST_META_OUT_DECL_TEMPLATE = """
{name}_meta_ptr_vec[i] = &{name}_meta_vec[i]; {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
}} }}
""" """
INFER_SPMD_TEMPLATE = """ INFER_GLOBAL_SHAPE_TEMPLATE = """
phi::{}({}{}); phi::{}({}{});
""" """
# 3. Select Kernel # 4. Select Kernel
KERNEL_SELECTION_TEMPLATE = """ KERNEL_SELECTION_TEMPLATE = """
VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
...@@ -137,14 +152,18 @@ KERNEL_SELECTION_TEMPLATE = """ ...@@ -137,14 +152,18 @@ KERNEL_SELECTION_TEMPLATE = """
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
""" """
# 4. Reshard Input # 5. Reshard Input
INPUT_RESHARD_TEMPLATE = """ SINGLE_INPUT_RESHARD_TEMPLATE = """
""" auto dist_input_{arg} = ReshardDistTensor(dev_ctx, {arg}, spmd_info.first[{idx}]);"""
# 5. PrepareData # 6. PrepareData
SINGLE_PREPARE_DATA_TEMPLATE = """ SINGLE_PREPARE_DATA_TEMPLATE = """
auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel); dist_input_{arg} = PrepareDataForDistTensor(dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
auto input_{} = &dist_input_{}->value(); auto input_{arg} = &dist_input_{arg}->value();
"""
SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """
auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
auto input_{arg} = &dist_input_{arg}->value();
""" """
VECTOR_PREPARE_DATA_TEMPLATE = """ VECTOR_PREPARE_DATA_TEMPLATE = """
auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
...@@ -170,7 +189,7 @@ INFER_META_VECTOR_INPUT_TEMPLATE = """ ...@@ -170,7 +189,7 @@ INFER_META_VECTOR_INPUT_TEMPLATE = """
const auto& input_{} = *input_{}_uq_ptr; const auto& input_{} = *input_{}_uq_ptr;
""" """
# 6. Infer Local DenseTensor Meta # 7. Infer Local DenseTensor Meta
SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """ SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """
# TODO(GhostScreaming): support optional args later # TODO(GhostScreaming): support optional args later
VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """ VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """
...@@ -189,7 +208,7 @@ INFER_META_TEMPLATE = """ ...@@ -189,7 +208,7 @@ INFER_META_TEMPLATE = """
phi::{}({}{}); phi::{}({}{});
""" """
# 7. DenseTensor Kernel Call # 8. DenseTensor Kernel Call
# TODO(chenweihang): support kernel fallback later # TODO(chenweihang): support kernel fallback later
SINGLE_OUTPUT_NAME = """dense_out""" SINGLE_OUTPUT_NAME = """dense_out"""
# TODO(chenweihang): support vector and tuple output later # TODO(chenweihang): support vector and tuple output later
...@@ -205,10 +224,6 @@ KERNEL_CALL_TEMPLATE = """ ...@@ -205,10 +224,6 @@ KERNEL_CALL_TEMPLATE = """
PREFIX_VECTOR_TENSOR_NAME = "dense_input_" PREFIX_VECTOR_TENSOR_NAME = "dense_input_"
SUFFIX_VECTOR_TENSOR_NAME = "_vec" SUFFIX_VECTOR_TENSOR_NAME = "_vec"
# 8. Reshard Output
OUTPUT_RESHARD_TEMPLATE = """
"""
# BaseAPI members: # BaseAPI members:
# inputs: # inputs:
# names : [], list of input names # names : [], list of input names
...@@ -252,6 +267,17 @@ class DistForwardAPI(ForwardAPI): ...@@ -252,6 +267,17 @@ class DistForwardAPI(ForwardAPI):
self.inplace_flag = False self.inplace_flag = False
self.dist_output_args = [] self.dist_output_args = []
self.dense_output_args = [] self.dense_output_args = []
self.input_args_code = ""
# override BaseAPI's method
def parse_infer_meta(self, infer_meta_config):
infer_meta = infer_meta_config
if 'param' not in infer_meta_config:
infer_meta['param'] = None
if 'spmd_rule' not in infer_meta_config:
infer_meta['spmd_rule'] = None
return infer_meta
def need_to_generate_code_for_inplace_impl(self, i): def need_to_generate_code_for_inplace_impl(self, i):
return ( return (
...@@ -289,6 +315,55 @@ class DistForwardAPI(ForwardAPI): ...@@ -289,6 +315,55 @@ class DistForwardAPI(ForwardAPI):
input_args = input_args[:-2] input_args = input_args[:-2]
return AUTO_PARALLEL_COND_TEMPLATE.format(input_args) return AUTO_PARALLEL_COND_TEMPLATE.format(input_args)
def generate_infer_spmd_code(self) -> str:
if self.infer_meta['spmd_rule'] is not None:
input_names = self.inputs['names']
attr_names = self.attrs['names']
infer_meta_params = (
self.infer_meta['param']
if self.infer_meta['param'] is not None
else input_names + attr_names
)
input_decl_code = ""
self.input_args_code = ""
for param in infer_meta_params:
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_decl_code += SINGLE_DIST_META_IN_TEMPLATE.format(
param, param
)
self.input_args_code += "meta_dist_" + param + ", "
else:
raise ValueError(
f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
self.input_args_code = self.input_args_code + param + ", "
elif isinstance(param, str):
self.input_args_code = (
self.input_args_code + "\"" + param + "\", "
)
elif isinstance(param, bool):
self.input_args_code = (
self.input_args_code + str(param).lower() + ", "
)
else:
self.input_args_code = (
self.input_args_code + str(param) + ", "
)
# TODO(chenweihang): add general spmd rule later
infer_spmd_code = ""
infer_spmd_func_code = self.infer_meta['spmd_rule']
infer_spmd_code = INFER_SPMD_TEMPLATE.format(
infer_spmd_func_code, self.input_args_code[:-2]
)
return input_decl_code + infer_spmd_code
else:
return ""
def generate_output_creation_code(self) -> str: def generate_output_creation_code(self) -> str:
# forward api need to generate api and kernel outputs # forward api need to generate api and kernel outputs
output_num = len(self.outputs['types']) output_num = len(self.outputs['types'])
...@@ -311,7 +386,10 @@ class DistForwardAPI(ForwardAPI): ...@@ -311,7 +386,10 @@ class DistForwardAPI(ForwardAPI):
self.dist_output_args.append('dist_out') self.dist_output_args.append('dist_out')
self.dense_output_args.append('dense_out') self.dense_output_args.append('dense_out')
if self.outputs['types'][0] == 'Tensor': if self.outputs['types'][0] == 'Tensor':
if self.infer_meta['spmd_rule'] is not None:
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE output_creation_code += SINGLE_OUT_CREATION_TEMPLATE
else:
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD
elif self.outputs['types'][0] == 'std::vector<Tensor>': elif self.outputs['types'][0] == 'std::vector<Tensor>':
output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format( output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
self.outputs['out_size_expr'][0] self.outputs['out_size_expr'][0]
...@@ -365,9 +443,16 @@ class DistForwardAPI(ForwardAPI): ...@@ -365,9 +443,16 @@ class DistForwardAPI(ForwardAPI):
) )
) )
else: else:
if self.infer_meta['spmd_rule'] is not None:
output_creation_code += ( output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE.format( MULTI_SINGLE_OUT_CREATION_TEMPLATE.format(
i, get_out_code, i, i idx=i, out=get_out_code
)
)
else:
output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE_NO_SPMD.format(
idx=i, out=get_out_code
) )
) )
else: else:
...@@ -379,10 +464,9 @@ class DistForwardAPI(ForwardAPI): ...@@ -379,10 +464,9 @@ class DistForwardAPI(ForwardAPI):
return output_creation_code return output_creation_code
def generate_infer_spmd_code(self) -> str: def generate_infer_global_shape_code(self) -> str:
input_names = self.inputs['names'] input_names = self.inputs['names']
attr_names = self.attrs['names'] attr_names = self.attrs['names']
output_names = self.outputs['names']
# 1. get infer meta func name # 1. get infer meta func name
infer_meta = self.infer_meta infer_meta = self.infer_meta
...@@ -399,18 +483,18 @@ class DistForwardAPI(ForwardAPI): ...@@ -399,18 +483,18 @@ class DistForwardAPI(ForwardAPI):
for param in infer_meta_params: for param in infer_meta_params:
if param in input_names: if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&": if self.inputs['input_info'][param] == "const Tensor&":
input_args_code += SINGLE_DIST_META_IN_TEMPLATE.format( input_args_code += SINGLE_GLOBAL_META_IN_TEMPLATE.format(
param param
) )
elif ( elif (
self.inputs['input_info'][param] self.inputs['input_info'][param]
== "const std::vector<Tensor>&" == "const std::vector<Tensor>&"
): ):
input_args_code += VECTOR_DIST_META_IN_TEMPLATE.format( input_args_code += VECTOR_GLOBAL_META_IN_TEMPLATE.format(
param param
) )
input_meta_code += VECTOR_DIST_META_IN_DECL_TEMPLATE.format( input_meta_code += (
name=param VECTOR_GLOBAL_META_IN_DECL_TEMPLATE.format(name=param)
) )
else: else:
raise ValueError( raise ValueError(
...@@ -430,7 +514,7 @@ class DistForwardAPI(ForwardAPI): ...@@ -430,7 +514,7 @@ class DistForwardAPI(ForwardAPI):
output_args_code = "" output_args_code = ""
for i, out_name in enumerate(self.dist_output_args): for i, out_name in enumerate(self.dist_output_args):
if self.outputs['types'][i] == 'std::vector<Tensor>': if self.outputs['types'][i] == 'std::vector<Tensor>':
output_decl_code += VECTOR_DIST_META_OUT_DECL_TEMPLATE.format( output_decl_code += VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE.format(
name=out_name name=out_name
) )
if len(self.dense_output_args) == 1: if len(self.dense_output_args) == 1:
...@@ -440,7 +524,7 @@ class DistForwardAPI(ForwardAPI): ...@@ -440,7 +524,7 @@ class DistForwardAPI(ForwardAPI):
f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, " f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, "
) )
else: else:
output_decl_code += SINGLE_DIST_META_OUT_DECL_TEMPLATE.format( output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format(
out_name, out_name out_name, out_name
) )
if len(self.dense_output_args) == 1: if len(self.dense_output_args) == 1:
...@@ -451,10 +535,12 @@ class DistForwardAPI(ForwardAPI): ...@@ -451,10 +535,12 @@ class DistForwardAPI(ForwardAPI):
) )
output_args_code = output_args_code[:-2] output_args_code = output_args_code[:-2]
if self.input_args_code != "":
input_args_code = self.input_args_code
return ( return (
output_decl_code output_decl_code
+ input_meta_code + input_meta_code
+ INFER_SPMD_TEMPLATE.format( + INFER_GLOBAL_SHAPE_TEMPLATE.format(
infer_meta_func_code, input_args_code, output_args_code infer_meta_func_code, input_args_code, output_args_code
) )
) )
...@@ -465,7 +551,35 @@ class DistForwardAPI(ForwardAPI): ...@@ -465,7 +551,35 @@ class DistForwardAPI(ForwardAPI):
) )
def generate_reshard_input_code(self) -> str: def generate_reshard_input_code(self) -> str:
return INPUT_RESHARD_TEMPLATE.format() input_reshard_code = ""
if self.infer_meta['spmd_rule'] is not None:
input_names = self.inputs['names']
infer_meta = self.infer_meta
infer_meta_params = (
infer_meta['param']
if infer_meta['param'] is not None
else input_names
)
for i, param in enumerate(infer_meta_params):
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_reshard_code += (
SINGLE_INPUT_RESHARD_TEMPLATE.format(
arg=param, idx=i
)
)
else:
raise ValueError(
f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported."
)
else:
# do nothing
pass
else:
# do nothingd
pass
return input_reshard_code
def generate_single_dense_input( def generate_single_dense_input(
self, self,
...@@ -479,13 +593,17 @@ class DistForwardAPI(ForwardAPI): ...@@ -479,13 +593,17 @@ class DistForwardAPI(ForwardAPI):
if kernel_param is None: if kernel_param is None:
kernel_param = input_names + attr_names kernel_param = input_names + attr_names
if self.infer_meta['spmd_rule'] is not None:
input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format( input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format(
input_name, arg=input_name,
input_name, idx=kernel_param.index(input_name),
kernel_param.index(input_name), flag=trans_flag,
trans_flag, )
input_name, else:
input_name, input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format(
arg=input_name,
idx=kernel_param.index(input_name),
flag=trans_flag,
) )
return input_tensor_code return input_tensor_code
...@@ -706,9 +824,6 @@ class DistForwardAPI(ForwardAPI): ...@@ -706,9 +824,6 @@ class DistForwardAPI(ForwardAPI):
", ".join(self.dense_output_args), ", ".join(self.dense_output_args),
) )
def generate_reshard_output_code(self) -> str:
return OUTPUT_RESHARD_TEMPLATE.format()
def generate_return_code(self) -> str: def generate_return_code(self) -> str:
return self.gene_return_code() return self.gene_return_code()
...@@ -718,14 +833,14 @@ class DistForwardAPI(ForwardAPI): ...@@ -718,14 +833,14 @@ class DistForwardAPI(ForwardAPI):
return "" return ""
return MAIN_DIST_BRANCH_TEMPLATE.format( return MAIN_DIST_BRANCH_TEMPLATE.format(
self.generate_if_condition_code(), self.generate_if_condition_code(),
self.generate_output_creation_code(),
self.generate_infer_spmd_code(), self.generate_infer_spmd_code(),
self.generate_output_creation_code(),
self.generate_infer_global_shape_code(),
self.generate_kernel_selection_code(), self.generate_kernel_selection_code(),
self.generate_reshard_input_code(), self.generate_reshard_input_code(),
self.generate_prepare_data_code(), self.generate_prepare_data_code(),
self.generate_infer_meta_code(), self.generate_infer_meta_code(),
self.generate_kernel_call_code(), self.generate_kernel_call_code(),
self.generate_reshard_output_code(),
self.generate_return_code(), self.generate_return_code(),
) )
...@@ -777,11 +892,14 @@ class DistForwardAPI(ForwardAPI): ...@@ -777,11 +892,14 @@ class DistForwardAPI(ForwardAPI):
# 3. doesn't support view api # 3. doesn't support view api
# 4. only for general forward and backward # 4. only for general forward and backward
# 5. only support single tensor input and output # 5. only support single tensor input and output
# 6. doesn't support double grad and triple grad
dist_branch_code = "" dist_branch_code = ""
if ( if (
len(self.inputs['names']) > 0 len(self.inputs['names']) > 0
and len(self.view_map) == 0 and len(self.view_map) == 0
and self.check_argument_whether_support_auto_parallel() and self.check_argument_whether_support_auto_parallel()
and not self.api.endswith("_double_grad")
and not self.api.endswith("_triple_grad")
): ):
dist_branch_code = self.generate_auto_paralel_branch() dist_branch_code = self.generate_auto_paralel_branch()
return API_IMPL_TEMPLATE.format( return API_IMPL_TEMPLATE.format(
......
...@@ -22,10 +22,24 @@ from dist_api_gen import DistForwardAPI ...@@ -22,10 +22,24 @@ from dist_api_gen import DistForwardAPI
# Code Gen Templates # # Code Gen Templates #
###################### ######################
MAIN_DIST_BRANCH_TEMPLATE = """
// Auto Parallel condition
if ({}) {{
// 1. Create API Output & Prepare Dist and Dense Output{}
// 2. Infer DistTensor's Global Shape{}
// 3. Select Kernel{}
// 4. PrepareData (DataTransform & Prepare Dense Input){}
// 5. Infer Local DenseTensor Meta{}
// 6. DenseTensor Kernel Call{}
// 7. Return
{}
}}
"""
# 1. Create API Outputs # 1. Create API Outputs
SINGLE_OUT_CREATION_TEMPLATE = """ SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({}); auto dist_out = SetKernelDistOutput({});
auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value()); auto dense_out = dist_out->unsafe_mutable_value();
""" """
VECTOR_OUT_CREATION_TEMPLATE = """ VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({name}); auto dist_out = SetKernelDistOutput({name});
...@@ -39,7 +53,21 @@ INPLACE_OUT_CREATION_TEMPLATE = """ ...@@ -39,7 +53,21 @@ INPLACE_OUT_CREATION_TEMPLATE = """
""" """
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({}); auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value()); auto dense_out_{} = dist_out_{}->unsafe_mutable_value();
"""
# 2. Infer Global Shape
SINGLE_DIST_META_IN_TEMPLATE = """MakeDistMetaTensor(*{}.impl()), """
SINGLE_DIST_META_OUT_DECL_TEMPLATE = """
phi::distributed::DistMetaTensor meta_{}({});"""
INFER_GLOBAL_SHAPE_TEMPLATE = """
phi::{}({}{});
"""
# 4. PrepareData (DataTransform & Prepare Dist and Dense Input)
SINGLE_PREPARE_DATA_TEMPLATE = """
auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
auto input_{arg} = &dist_input_{}->value();
""" """
...@@ -131,6 +159,21 @@ class DistBackwardAPI(DistForwardAPI, BackwardAPI): ...@@ -131,6 +159,21 @@ class DistBackwardAPI(DistForwardAPI, BackwardAPI):
def gene_api_declaration(self) -> str: def gene_api_declaration(self) -> str:
return BackwardAPI.gene_api_declaration(self) return BackwardAPI.gene_api_declaration(self)
def generate_auto_paralel_branch(self) -> str:
# if no tensor input, do not genetate auto parallel branch
if len(self.inputs['names']) == 0:
return ""
return MAIN_DIST_BRANCH_TEMPLATE.format(
self.generate_if_condition_code(),
self.generate_output_creation_code(),
self.generate_infer_global_shape_code(),
self.generate_kernel_selection_code(),
self.generate_prepare_data_code(),
self.generate_infer_meta_code(),
self.generate_kernel_call_code(),
self.generate_return_code(),
)
def header_include(): def header_include():
return """ return """
......
...@@ -651,6 +651,7 @@ ...@@ -651,6 +651,7 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : MatmulInferMeta func : MatmulInferMeta
spmd_rule : MatmulSpmdInferForward
kernel : kernel :
func : matmul func : matmul
backward : matmul_grad backward : matmul_grad
......
...@@ -66,6 +66,15 @@ class DistTensor final ...@@ -66,6 +66,15 @@ class DistTensor final
/// \return The DenseTensor value's const reference /// \return The DenseTensor value's const reference
const DenseTensor& value() const { return value_; } const DenseTensor& value() const { return value_; }
/// \brief Returns the mutable dense tensor value in dist tensor.
/// \note If DenseTensor value is modified externally, the corresponding
/// relationship between it and the current tensor's global dims and
/// dist attr may be destroyed, which may introduce some subtle bugs,
/// so you need to make sure to consider it thoroughly when using
/// this method.
/// \return The mutable pointer of DenseTensor value
DenseTensor* unsafe_mutable_value() { return &value_; }
/// \brief Returns the global dims of the dist tensor. /// \brief Returns the global dims of the dist tensor.
/// \return The global dims of the dist tensor. /// \return The global dims of the dist tensor.
const DDim& local_dims() const; const DDim& local_dims() const;
......
...@@ -73,11 +73,6 @@ bool SpmdRuleFactory::ContainsSpmdRule(const std::string& kernel_name) const { ...@@ -73,11 +73,6 @@ bool SpmdRuleFactory::ContainsSpmdRule(const std::string& kernel_name) const {
} }
int SpmdRuleFactory::InsertSpmdRule(std::string kernel_name, SpmdRule rule) { int SpmdRuleFactory::InsertSpmdRule(std::string kernel_name, SpmdRule rule) {
PADDLE_ENFORCE_NE(
ContainsSpmdRule(kernel_name),
true,
phi::errors::AlreadyExists(
"`%s` Kernel's Spmd rules has been registered.", kernel_name));
spmd_rule_map_.insert({std::move(kernel_name), std::move(rule)}); spmd_rule_map_.insert({std::move(kernel_name), std::move(rule)});
return 0; return 0;
} }
......
...@@ -27,12 +27,12 @@ limitations under the License. */ ...@@ -27,12 +27,12 @@ limitations under the License. */
* 2. Since the infer functions of Spmd forward and backward are closely related * 2. Since the infer functions of Spmd forward and backward are closely related
* and need to be registered together, we manage them together in one file. * and need to be registered together, we manage them together in one file.
* *
* 3. SPMD rules are much smaller than infermeta function, and we manage files * 3. SPMD rules are less than infermeta function, and we manage files by
* in operator units. * operator.
* *
* 4. The previous registration used some compile-time regular matching methods, * 4. The previous registration used some compile-time regular matching methods,
* which was less flexible, and the registration of SPMD rules here is declare * which was less flexible, and the registration of SPMD rules here is declare
* directly in the header file * directly in the header file.
*/ */
namespace phi { namespace phi {
......
...@@ -88,6 +88,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -88,6 +88,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p) py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
set_tests_properties(test_reshard_r_to_p set_tests_properties(test_reshard_r_to_p
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100) PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
py_test_modules(test_semi_auto_parallel_basic MODULES
test_semi_auto_parallel_basic)
set_tests_properties(test_semi_auto_parallel_basic
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
# End of unittests WITH multi cards and timeout # End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
class TestMatmulApiForSemiAutoParallel:
def __init__(self):
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
def test_body(self, x_specs, y_specs):
x_shape = [64, 32]
y_shape = [32, 48]
x = paddle.randn(x_shape, self._dtype)
y = paddle.randn(y_shape, self._dtype)
x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs)
y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs)
dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr)
dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr)
dist_out = paddle.matmul(dist_x, dist_y)
# verify global shape
out_shape = [64, 48]
np.testing.assert_equal(dist_out.shape, out_shape, verbose=True)
return dist_out
def test_case1(self):
# case1: mk[0,-1],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[]
dist_out = self.test_body(x_specs=['x', None], y_specs=[None, None])
# verify local shape and dist attr
np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True)
np.testing.assert_equal(
dist_out.dist_attr.dims_mapping, [0, -1], verbose=True
)
assert dist_out.dist_attr._is_partial() is False
def test_case2(self):
# case2: mk[-1, 0],kn[-1,-1] --> mk[-1, 0],kn[0, -1] = nm[-1, -1] partial[0]
dist_out = self.test_body(x_specs=[None, 'x'], y_specs=[None, None])
# verify local shape
np.testing.assert_equal(dist_out._local_shape, [64, 48], verbose=True)
np.testing.assert_equal(
dist_out.dist_attr.dims_mapping, [-1, -1], verbose=True
)
assert dist_out.dist_attr._is_partial() is True
assert dist_out.dist_attr._partial_dims() == {0}
def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
elif self._backend == "gpu":
paddle.set_device("gpu:" + str(dist.get_rank()))
else:
raise ValueError("Only support cpu or gpu backend.")
self.test_case1()
self.test_case2()
if __name__ == '__main__':
TestMatmulApiForSemiAutoParallel().run_test_case()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import collective.test_communication_api_base as test_base
class TestSemiAutoParallelMatmul(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"dtype": "float32",
"seeds": str(self._seeds),
}
self._changeable_envs = {"backend": ["cpu", "gpu"]}
def test_matmul_api(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_for_matmul.py",
user_defined_envs=envs,
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册