Unverified commit 7039bef3, authored by Chen Weihang, committed by GitHub

[AutoParallel] Dygraph basic impl for semi auto parallel (#55698)

* add phi forward api gen impl

* add phi backward gen code

* polish api code gen impl

* polish code gen impl

* remove auto_parallel namespace

* add dygraph forward impl

* add for_auto_parallel cond

* fix code gen errors

* add dygraph backward impl

* resolve conflict with develop

* refactor dist api gen impl

* revert origin api gen impl

* replace template for override func

* fix dnnl macro error

* revert third_party change

* add with distributed macro

* Update grad_tensor_holder.cc details

* merge dist tensor constructor

* change test tensor to replicate

* fix typo

* resolve conflict with develop

* fix out dim error
Parent fcde3991
......@@ -26,8 +26,12 @@
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
/**
* Implementation of GradNodeBase, Edge and GradTensorHolder.
......@@ -121,6 +125,14 @@ void GradNodeBase::SetGradInMeta(const paddle::Tensor& fwd_out,
phi::SparseCsrTensor* csr_tensor =
static_cast<phi::SparseCsrTensor*>(fwd_out.impl().get());
dense_tensor = csr_tensor->mutable_non_zero_elements();
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(fwd_out.impl().get())) {
// TODO(chenweihang): DistTensor contains both global and local meta; for now
// we only set the local meta, the global meta should be set later
dense_tensor =
static_cast<phi::distributed::DistTensor*>(fwd_out.impl().get())
->mutable_value();
#endif
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
......@@ -256,10 +268,28 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.place());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(fwd_in.impl().get())) {
phi::DenseTensor* dense_tensor =
static_cast<phi::distributed::DistTensor*>(fwd_in.impl().get())
->mutable_value();
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype,
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.place());
#endif
} else {
VLOG(7)
<< "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
}
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
VLOG(7) << "Unable to initialize the DenseTensorMeta because the Tensor "
"is not initialized.";
}
}
......@@ -367,7 +397,8 @@ void GradNodeBase::SetGradOutMeta(const std::vector<paddle::Tensor>& fwd_in,
// Record TensorMeta
if (fwd_in_tensor.impl() && fwd_in_tensor.impl().get()) {
if (phi::DenseTensor::classof(fwd_in_tensor.impl().get())) {
// Only Copy Meta
// TODO(chenweihang): DistTensor contains both global and local meta; for now
// we only set the local meta, the global meta should be set later
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in_tensor.impl().get());
PADDLE_ENFORCE_NE(dense_tensor->dtype(),
......
......@@ -20,6 +20,10 @@
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace egr {
......@@ -83,6 +87,23 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id,
} else if (t.is_sparse_csr_tensor() || t.is_sparse_coo_tensor()) {
buffer_[slot_id][rank] =
paddle::experimental::sparse::full_like(t, 1, t.dtype());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (t.is_dist_tensor()) {
VLOG(6) << "Create a new dist tensor.";
// TODO(chenweihang): we need a shard_tensor API in C++
// TODO(chenweihang): replace by valid dist_attr later
auto temp =
paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
auto dense_temp =
std::dynamic_pointer_cast<phi::DenseTensor>(temp.impl());
auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
dense_temp,
dense_temp->meta(),
std::make_shared<
phi::distributed::auto_parallel::TensorDistAttr>());
temp.set_impl(dist_tensor);
buffer_[slot_id][rank] = temp;
#endif
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Only Support DENSE_TENSOR, SPARSE_COO_TENSOR, SPARSE_CSR_TENSOR "
......@@ -178,6 +199,10 @@ void GradTensorHolder::add(size_t slot_id,
&buffer_values);
}
}
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (t.is_dist_tensor()) {
buffer_tensor = add_ad_func(t, buffer_tensor);
#endif
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function
......
......@@ -297,13 +297,15 @@ void InitDistTensorWithTensor(
if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, tensor->meta(), dist_attr));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
self->tensor.set_impl(
std::make_shared<DistTensor>(tensor, tensor->meta(), dist_attr));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
}
if (src.get_autograd_meta()) {
......
......@@ -249,6 +249,27 @@ static PyObject* tensor_method_numpy(TensorObject* self,
place,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size());
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (self->tensor.is_dist_tensor()) {
// TODO(chenweihang): DistTensor is handled as a local DenseTensor for now;
// if the local DenseTensor is shard or partial, should we do a gather or reduce?
VLOG(6) << "Getting DistTensor's numpy value";
auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
auto& dense_tensor = dist_tensor->value();
cpu_tensor.set_meta(dense_tensor.meta());
// deep copy
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
// deep copy
paddle::memory::Copy(place,
cpu_tensor.Holder()->ptr(),
place,
dense_tensor.Holder()->ptr(),
dense_tensor.Holder()->size());
#endif
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
......@@ -290,6 +311,22 @@ static PyObject* tensor_method_numpy(TensorObject* self,
dense_tensor->Holder()->ptr(),
dense_tensor->Holder()->size(),
kind);
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (self->tensor.is_dist_tensor()) {
VLOG(6) << "Getting DistTensor's numpy value";
auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
auto& dense_tensor = dist_tensor->value();
cpu_tensor.set_meta(dense_tensor.meta());
auto tmp_allocation_ptr =
memory::Alloc(cpu_place, dense_tensor.Holder()->size());
cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
dense_tensor.Holder()->ptr(),
dense_tensor.Holder()->size(),
kind);
#endif
} else {
VLOG(6) << "Getting DenseTensor's numpy value";
auto dense_tensor =
......
......@@ -9,6 +9,9 @@ set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/api.h)
set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/api.cc)
set(api_header_file_tmp ${api_header_file}.tmp)
set(api_source_file_tmp ${api_source_file}.tmp)
# dist forward api file
set(dist_api_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/dist_api_gen.py)
# backward api file
set(bw_api_gen_file
......@@ -21,6 +24,9 @@ set(bw_api_header_file
set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/backward_api.cc)
set(bw_api_header_file_tmp ${bw_api_header_file}.tmp)
set(bw_api_source_file_tmp ${bw_api_source_file}.tmp)
# dist backward api file
set(dist_bw_api_gen_file
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/dist_bw_api_gen.py)
# dygraph(intermediate) api file
set(im_api_gen_file
......@@ -124,19 +130,37 @@ endif()
execute_process(COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml)
# generate forward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${api_gen_file} --api_yaml_path ${api_yaml_file}
${legacy_api_yaml_file} --api_header_path ${api_header_file_tmp}
--api_source_path ${api_source_file_tmp})
if(WITH_DISTRIBUTE)
# generate dist forward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${dist_api_gen_file} --api_yaml_path
${api_yaml_file} ${legacy_api_yaml_file} --api_header_path
${api_header_file_tmp} --api_source_path ${api_source_file_tmp})
# generate backward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${bw_api_gen_file} --backward_yaml_path
${bw_api_yaml_file} ${legacy_bw_api_yaml_file} --backward_header_path
${bw_api_header_file_tmp} --backward_source_path ${bw_api_source_file_tmp})
# generate dist backward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${dist_bw_api_gen_file} --backward_yaml_path
${bw_api_yaml_file} ${legacy_bw_api_yaml_file} --backward_header_path
${bw_api_header_file_tmp} --backward_source_path
${bw_api_source_file_tmp})
else()
# generate forward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${api_gen_file} --api_yaml_path ${api_yaml_file}
${legacy_api_yaml_file} --api_header_path ${api_header_file_tmp}
--api_source_path ${api_source_file_tmp})
# generate backward api
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${bw_api_gen_file} --backward_yaml_path
${bw_api_yaml_file} ${legacy_bw_api_yaml_file} --backward_header_path
${bw_api_header_file_tmp} --backward_source_path
${bw_api_source_file_tmp})
endif()
# generate fused_op api
execute_process(
......
......@@ -19,6 +19,13 @@ limitations under the License. */
DECLARE_bool(use_stride_kernel);
#include "glog/logging.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace paddle {
namespace experimental {
......@@ -475,5 +482,26 @@ void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from,
phi::SelectedRows* to) {}
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
if (out) {
// TODO(chenweihang): currently all dist cases are nullptr
if (out->impl() == nullptr) {
auto dense_t = std::make_shared<phi::DenseTensor>();
// TODO(chenweihang): polish code, dist_attr is null now
auto dist_attr =
std::make_shared<phi::distributed::auto_parallel::TensorDistAttr>();
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
dense_t, phi::DenseTensorMeta(), dist_attr);
out->set_impl(dist_t);
}
return static_cast<phi::distributed::DistTensor*>(out->impl().get());
}
return nullptr;
}
#endif
} // namespace experimental
} // namespace paddle
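For orientation, here is a minimal, hand-written usage sketch of SetKernelDistOutput as the generated forward API is expected to call it (assembled from SINGLE_OUT_CREATION_TEMPLATE later in this change; the names are illustrative, not literal generator output):

// hypothetical snippet from a generated forward API
Tensor api_output;
auto dist_out = SetKernelDistOutput(&api_output);  // creates a DistTensor impl when out->impl() is null
auto dense_out = dist_out->mutable_value();        // local DenseTensor the selected kernel writes into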
......@@ -24,6 +24,12 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/string_tensor.h"
namespace phi {
namespace distributed {
class DistTensor;
} // namespace distributed
} // namespace phi
namespace paddle {
namespace experimental {
......@@ -127,5 +133,11 @@ void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from,
phi::SelectedRows* to);
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out);
#endif
} // namespace experimental
} // namespace paddle
......@@ -27,6 +27,10 @@ limitations under the License. */
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/transfer_layout_kernel.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
DECLARE_bool(use_stride_kernel);
namespace paddle {
......@@ -567,5 +571,46 @@ void TransDataBackend(const phi::SelectedRows* tensor,
}
}
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
const auto& tensor_in = input.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
phi::DenseTensor& dense_tensor = *(dist_tensor->mutable_value());
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace(
dense_tensor.place(), target_args_def.backend, transform_flag) &&
!NeedTransformDataType(
dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
!NeedTransformLayout(dense_tensor.layout(),
target_args_def.layout,
dense_tensor.place(),
transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) {
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
phi::DenseTensor out = TransformData(
&dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(chenweihang): The global meta in DistTensor is not changed, but the
// local meta in DenseTensor may be changed, such as a layout change
// (NCHW->NHWC), so the new DistTensor's meta may not be unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
return std::make_shared<phi::distributed::DistTensor>(
std::make_shared<phi::DenseTensor>(std::move(out)),
dist_tensor->meta(),
dist_tensor->dist_attr());
}
return nullptr;
}
#endif
} // namespace experimental
} // namespace paddle
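As a hedged usage sketch (following SINGLE_PREPARE_DATA_TEMPLATE in dist_api_gen.py below; the input name x and kernel slot 0 are illustrative assumptions), a generated API prepares the local dense input roughly like this:

// hypothetical snippet from a generated forward API
auto dist_input_x = PrepareDataForDistTensor(
    x,
    GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
    {},  // TransformFlag
    kernel_result.is_stride_kernel);
auto input_x = dist_input_x->mutable_value();  // local DenseTensor passed to the kernel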
......@@ -20,6 +20,12 @@ limitations under the License. */
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace distributed {
class DistTensor;
} // namespace distributed
} // namespace phi
namespace paddle {
namespace experimental {
......@@ -165,5 +171,16 @@ inline bool NeedTransformPlace(const phi::Place& src_place,
return ret;
}
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
// TODO(chenweihang): impl Reshard input and output function
std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const Tensor& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);
#endif
} // namespace experimental
} // namespace paddle
......@@ -173,6 +173,36 @@ struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
}
};
#ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */
struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
bool result = true;
void operator()(const Tensor& x) { result &= x.is_dist_tensor(); }
void operator()(const paddle::optional<Tensor>& x) {
if (x) {
result &= x.get_ptr()->is_dist_tensor();
}
}
void operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
for (auto& t : x) {
result &= t.is_dist_tensor();
}
}
}
// skip other type args; these args are not used in kernel selection
template <typename T>
void operator()(const T& x) {
// do nothing
}
};
#endif
} // namespace detail
template <typename... Args>
......@@ -205,5 +235,12 @@ DataLayout ParseLayout(DataLayout layout);
DataLayout ParseLayout(const Tensor& tensor);
DataLayout ParseLayoutWithInputOrder(DataLayout layout, const Tensor& tensor);
#ifdef PADDLE_WITH_DISTRIBUTE
template <typename... Args>
bool AllInputsAreDistTensor(const Args&... args) {
return detail::DistTensorTypeParser().apply(args...).result;
}
#endif
} // namespace experimental
} // namespace paddle
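A rough sketch of how the generated APIs are expected to gate the auto-parallel path with this helper (following MAIN_DIST_BRANCH_TEMPLATE in dist_api_gen.py below; the argument names are illustrative):

// hypothetical shape of the generated dispatch code
if (AllInputsAreDistTensor(x, y)) {
  // 1-9: create dist outputs, InferSPMD, select kernel, reshard inputs,
  // prepare local dense inputs, infer local meta, call the DenseTensor
  // kernel, reshard outputs, and return (see MAIN_DIST_BRANCH_TEMPLATE)
}
// otherwise fall through to the existing dense-tensor code path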
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import yaml
from api_base import PREFIX_TENSOR_NAME
from api_gen import (
ForwardAPI,
api_namespace,
declare_extension_api,
header_include,
source_include,
)
######################
# Code Gen Templates #
######################
API_IMPL_TEMPLATE = """
PADDLE_API {} {}({}) {{
// Kernel Key Construction{}
// Kernel Dispatch Body{}
}}
"""
DIPATCH_END_GUARD_TEMPLATE = """
PADDLE_THROW(phi::errors::Unimplemented(
"The kernel of ({}) for input tensors is unimplemented, please check the type of input tensors."));
"""
# TODO(chenweihang): add profile function code later
# TODO(chenweihang): add view support later
MAIN_DIST_BRANCH_TEMPLATE = """
// Auto Parallel condition
if ({}) {{
// 1. Create API Output & Prepare Dist and Dense Output{}
// 2. InferSPMD (Infer Global Shape and DistAttr of Inputs&Outputs){}
// 3. Select Kernel{}
// 4. Reshard Input{}
// 5. PrepareData (DataTransform & Prepare Dist and Dense Input){}
// 6. Infer Local DenseTensor Meta{}
// 7. DenseTensor Kernel Call{}
// 8. Reshard Output{}
// 9. Return
{}
}}
"""
# Auto Parallel condition
AUTO_PARALLEL_COND_TEMPLATE = """AllInputsAreDistTensor({})"""
# 1. Create API Outputs
API_OUT_CREATION_TEMPLATE = """
{} api_output{};
"""
INPLACE_API_OUT_CREATION_TEMPLATE = """
{} api_output{{{}}};
"""
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput(&api_output);
auto dense_out = dist_out->mutable_value();
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = dist_out_{}->mutable_value();
"""
# TODO(chenweihang): support vector and tuple output later
VECTOR_OUT_CREATION_TEMPLATE = """
"""
MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = {}({}, {});
auto dense_out_{} = dist_out_{}->mutable_value();
"""
TUPLE_OUT_CREATION_TEMPLATE = """
"""
# 2. InferSPMD
# Call InferMeta for now; replace with the InferSPMD function later
# TODO(chenweihang): InferSPMD function design
SINGLE_DIST_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """
# TODO(chenweihang): support vector and optional args later
VECTOR_DIST_META_IN_TEMPLATE = """
"""
OPTIONAL_DIST_VECTOR_META_IN_TEMPLATE = """
"""
SINGLE_DIST_META_OUT_DECL_TEMPLATE = """
phi::MetaTensor meta_{}({});"""
INFER_SPMD_TEMPLATE = """
phi::{}({}{});
"""
# 3. Select Kernel
KERNEL_SELECTION_TEMPLATE = """
VLOG(6) << "{} API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{}", {{kernel_backend, kernel_layout, kernel_data_type}});
const auto& kernel = kernel_result.kernel;
VLOG(6) << "{} kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
"""
# 4. Reshard Input
INPUT_RESHARD_TEMPLATE = """
"""
# 5. PrepareData
SINGLE_PREPARE_DATA_TEMPLATE = """
auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel);
auto input_{} = dist_input_{}->mutable_value();
"""
INFER_META_SINGLE_INPUT_TEMPLATE = """
auto dist_input_{} = {}.impl();
auto input_{} = static_cast<phi::distributed::DistTensor*>(dist_input_{}.get())->mutable_value();
"""
INFER_META_OPTIONAL_INPUT_TEMPLATE = """
paddle::optional<phi::TensorBase> input_{} = {} ? paddle::optional<phi::TensorBase>(*{}->impl()) : paddle::none;
"""
INFER_META_VECTOR_INPUT_TEMPLATE = """
auto input_{}_uq_ptr = TensorToDenseTensor({});
const auto& input_{} = *input_{}_uq_ptr;
"""
# 6. Infer Local DenseTensor Meta
SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """
# TODO(chenweihang): support vector and optional args later
VECTOR_META_IN_TEMPLATE = """
"""
OPTIONAL_VECTOR_META_IN_TEMPLATE = """
"""
SINGLE_META_OUT_DECL_TEMPLATE = """
phi::MetaTensor meta_{}({});"""
INFER_META_TEMPLATE = """
phi::{}({}{});
"""
# 7. DenseTensor Kernel Call
# TODO(chenweihang): support kernel fallback later
SINGLE_OUTPUT_NAME = """dense_out"""
# TODO(chenweihang): support vector and tuple output later
VECTOR_OUTPUT_NAME_TEMPLATE = """
"""
TUPLE_OUTPUT_NAME_TEMPLATE = """
"""
KERNEL_CALL_TEMPLATE = """
using kernel_signature = {};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({}, {});
"""
# 8. Reshard Output
OUTPUT_RESHARD_TEMPLATE = """
"""
# BaseAPI members:
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
# outputs:
# names : [], list of output names
# types : [], list of output types
# out_size_expr : [], expression for getting size of vector<Tensor>
class DistForwardAPI(ForwardAPI):
def __init__(self, api_item_yaml):
super().__init__(api_item_yaml)
self.init_dist_api_members()
def init_dist_api_members(self):
self.gene_dist_input_func = {
"const Tensor&": {
"dense": self.generate_dense_input,
},
"const paddle::optional<Tensor>&": {
"dense": self.generate_dense_input,
},
}
self.inplace_flag = False
self.dist_output_args = []
self.dense_output_args = []
def need_to_generate_code_for_inplace_impl(self, i):
return (
self.inplace_flag
and self.inplace_map is not None
and self.outputs['names'][i] in self.inplace_map
)
def need_to_generate_code_for_view_impl(self, i):
return (
not self.inplace_flag
and self.view_map is not None
and self.outputs['names'][i] in self.view_map
)
def is_inplace_output(self, i):
return self.outputs['names'][i] in self.inplace_map
def is_inplace_and_optional_output(self, i):
return (
self.outputs['names'][i] in self.inplace_map
and self.inplace_map[self.outputs['names'][i]] in self.optional_vars
)
def vector_output_size_assertion_check(self):
assert (
self.outputs['out_size_expr'] is not None
), f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
def generate_if_condition_code(self) -> str:
input_args = ""
for input_name in self.inputs['names']:
input_args = input_args + input_name + ", "
if len(input_args) > 2:
input_args = input_args[:-2]
return AUTO_PARALLEL_COND_TEMPLATE.format(input_args)
def generate_output_creation_code(self) -> str:
# forward api needs to generate both api and kernel outputs
output_num = len(self.outputs['types'])
return_type = self.get_return_type_with_intermediate(self.inplace_flag)
output_creation_code = ""
if output_num == 1:
# api output generate
if self.need_to_generate_code_for_inplace_impl(0):
inplace_assign_code = (
" = " + self.inplace_map[self.outputs['names'][0]]
)
output_creation_code += API_OUT_CREATION_TEMPLATE.format(
return_type, inplace_assign_code
)
else:
output_creation_code += API_OUT_CREATION_TEMPLATE.format(
return_type, ""
)
# kernel output generate
self.dist_output_args.append('dist_out')
self.dense_output_args.append('dense_out')
if self.outputs['types'][0] == 'Tensor':
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE
else:
self.vector_output_size_assertion_check()
elif output_num > 1:
# api output generate
if self.inplace_flag:
inplace_assign_code = ""
for i, out_name in enumerate(self.outputs['names']):
if self.need_to_generate_code_for_inplace_impl(i):
inplace_assign_code += self.inplace_map[out_name] + ', '
else:
inplace_assign_code += 'Tensor(), '
inplace_assign_code = inplace_assign_code[:-2]
output_creation_code += (
INPLACE_API_OUT_CREATION_TEMPLATE.format(
return_type, inplace_assign_code
)
)
else:
output_creation_code += API_OUT_CREATION_TEMPLATE.format(
return_type, ""
)
# kernel output generate
for i, out_type in enumerate(self.outputs['types']):
self.dist_output_args.append(f'dist_out_{i}')
self.dense_output_args.append(f'dense_out_{i}')
set_out_func = "SetKernelDistOutput"
get_out_code = f"&std::get<{i}>(api_output)"
if self.is_inplace_and_optional_output(i):
get_out_code = f"std::get<{i}>(api_output).get_ptr()"
if out_type == 'std::vector<Tensor>':
self.vector_output_size_assertion_check()
# Special case for inplace vector and inplace optional<vector>
# TODO(chenweihang): support this branch later
if self.is_inplace_output(i):
set_out_func = "SetInplaceVectorKernelOutput"
if self.is_inplace_and_optional_output(i):
set_out_func = (
"SetInplaceOptionalVectorKernelOutput"
)
get_out_code = f"std::get<{i}>(api_output)"
output_creation_code += (
MULTI_VECTOR_OUT_CREATION_TEMPLATE.format(
i,
set_out_func,
self.outputs['out_size_expr'][i],
get_out_code,
i,
i,
)
)
else:
output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE.format(
i, get_out_code, i, i
)
)
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.api
)
)
return output_creation_code
def generate_infer_spmd_code(self) -> str:
input_names = self.inputs['names']
attr_names = self.attrs['names']
output_names = self.outputs['names']
# 1. get infer meta func name
infer_meta = self.infer_meta
infer_meta_func_code = infer_meta['func']
# 2. get meta tensor input args
infer_meta_params = (
infer_meta['param']
if infer_meta['param'] is not None
else input_names + attr_names
)
input_args_code = ""
for param in infer_meta_params:
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_args_code += SINGLE_DIST_META_IN_TEMPLATE.format(
param
)
else:
raise ValueError(
f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
input_args_code = input_args_code + param + ", "
elif isinstance(param, str):
input_args_code = input_args_code + "\"" + param + "\", "
elif isinstance(param, bool):
input_args_code = input_args_code + str(param).lower() + ", "
else:
input_args_code = input_args_code + str(param) + ", "
# 3. get meta tensor output args
output_decl_code = ""
output_args_code = ""
for i, out_name in enumerate(self.dist_output_args):
if self.outputs['types'][i] == 'std::vector<Tensor>':
# TODO(chenweihang): support vector output later
pass
else:
output_decl_code += SINGLE_DIST_META_OUT_DECL_TEMPLATE.format(
out_name, out_name
)
if len(self.dense_output_args) == 1:
output_args_code += f"&meta_{out_name}, "
else:
output_args_code += (
f"{out_name} ? &meta_{out_name} : nullptr, "
)
output_args_code = output_args_code[:-2]
return output_decl_code + INFER_SPMD_TEMPLATE.format(
infer_meta_func_code, input_args_code, output_args_code
)
def generate_kernel_selection_code(self) -> str:
return KERNEL_SELECTION_TEMPLATE.format(
self.api, self.kernel['func'][0], self.kernel['func'][0]
)
def generate_reshard_input_code(self) -> str:
return INPUT_RESHARD_TEMPLATE.format()
# override BaseAPI's method
def generate_dense_input(
self,
input_name,
):
input_tensor_code = ""
trans_flag = self.gene_trans_flag(input_name)
input_names = self.inputs['names']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format(
input_name,
input_name,
kernel_param.index(input_name),
trans_flag,
input_name,
input_name,
)
return input_tensor_code
def generate_prepare_data_code(self) -> str:
input_names = self.inputs['names']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
if input_name in kernel_param:
# only support dense tensor
api_tensor_type = self.inputs['input_info'][input_name]
phi_tensor_type = 'dense'
if api_tensor_type in self.gene_dist_input_func.keys():
input_tensor_code += self.gene_dist_input_func[
api_tensor_type
][phi_tensor_type](input_name)
else:
# do nothing
pass
else:
if input_name in self.infer_meta['param']:
if input_name in self.optional_vars:
input_tensor_code += (
INFER_META_OPTIONAL_INPUT_TEMPLATE.format(
input_name, input_name, input_name, input_name
)
)
else:
if (
self.inputs['input_info'][input_name]
== "const std::vector<Tensor>&"
):
input_tensor_code += (
INFER_META_VECTOR_INPUT_TEMPLATE.format(
input_name, input_name, input_name
)
)
else:
input_tensor_code += (
INFER_META_SINGLE_INPUT_TEMPLATE.format(
input_name,
input_name,
input_name,
input_name,
)
)
return input_tensor_code
def generate_infer_meta_code(self) -> str:
input_names = self.inputs['names']
attr_names = self.attrs['names']
output_names = self.outputs['names']
# 1. get infer meta func name
infer_meta = self.infer_meta
infer_meta_func_code = infer_meta['func']
# 2. get meta tensor input args
infer_meta_params = (
infer_meta['param']
if infer_meta['param'] is not None
else input_names + attr_names
)
input_args_code = ""
for param in infer_meta_params:
if param in input_names:
if self.inputs['input_info'][param] == "const Tensor&":
input_args_code += SINGLE_META_IN_TEMPLATE.format(param)
else:
raise ValueError(
f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
input_args_code = input_args_code + param + ", "
elif isinstance(param, str):
input_args_code = input_args_code + "\"" + param + "\", "
elif isinstance(param, bool):
input_args_code = input_args_code + str(param).lower() + ", "
else:
input_args_code = input_args_code + str(param) + ", "
# 3. get meta tensor output args
output_decl_code = ""
output_args_code = ""
for i, out_name in enumerate(self.dense_output_args):
if self.outputs['types'][i] == 'std::vector<Tensor>':
# TODO(chenweihang): support vector output later
pass
else:
output_decl_code += SINGLE_META_OUT_DECL_TEMPLATE.format(
out_name, out_name
)
if len(self.dense_output_args) == 1:
output_args_code += f"&meta_{out_name}, "
else:
output_args_code += (
f"{out_name} ? &meta_{out_name} : nullptr, "
)
output_args_code = output_args_code[:-2]
return output_decl_code + INFER_META_TEMPLATE.format(
infer_meta_func_code, input_args_code, output_args_code
)
def generate_kernel_call_code(self) -> str:
dense_input_trans_map = {
'const Tensor&': 'const phi::DenseTensor&',
'const std::vector<Tensor>&': 'const std::vector<const phi::DenseTensor*>&',
'const paddle::optional<Tensor&>': 'paddle::optional<const phi::DenseTensor&>',
'const paddle::optional<Tensor>&': 'const paddle::optional<phi::DenseTensor>&',
'const paddle::optional<std::vector<Tensor>>&': 'const paddle::optional<std::vector<const phi::DenseTensor*>>&',
}
dense_output_trans_map = {
'Tensor': 'phi::DenseTensor*',
'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>',
}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const phi::DeviceContext&']
attr_names = self.attrs['names']
kernel_args = self.kernel['param']
if kernel_args is None:
kernel_args = input_names + attr_names
# 1. generate input args list
input_args = ["*dev_ctx"]
for arg in kernel_args:
if arg in input_names:
if arg in self.optional_vars:
input_args.append(PREFIX_TENSOR_NAME + arg)
else:
if input_infos[arg] == "const Tensor&":
input_args.append("*" + PREFIX_TENSOR_NAME + arg)
elif input_infos[arg] == "const std::vector<Tensor>&":
input_args.append(PREFIX_TENSOR_NAME + arg)
else:
# do nothing
pass
kernel_args_type_list.append(
dense_input_trans_map[input_infos[arg]]
)
elif arg in attr_names:
if 'IntArray' in self.attrs['attr_info'][arg][0]:
kernel_args_type_list.append('const phi::IntArray&')
arg = 'phi::IntArray(' + arg + ')'
elif 'vector<phi::Scalar>' in self.attrs['attr_info'][arg][0]:
kernel_args_type_list.append(
'const std::vector<phi::Scalar>&'
)
elif 'Scalar' in self.attrs['attr_info'][arg][0]:
kernel_args_type_list.append('const phi::Scalar&')
arg = 'phi::Scalar(' + arg + ')'
else:
kernel_args_type_list.append(
self.attrs['attr_info'][arg][0]
)
input_args.append(arg)
elif isinstance(arg, bool):
input_args.append(str(arg).lower())
else:
input_args.append(str(arg))
# 2. generate output args list
# record into `self.dense_output_args` in `generate_output_creation_code` function
# 3. generate kernel signature
for i, out_type in enumerate(self.outputs['types']):
kernel_args_type_list.append(dense_output_trans_map[out_type])
kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
return KERNEL_CALL_TEMPLATE.format(
kernel_signature,
", ".join(input_args),
", ".join(self.dense_output_args),
)
def generate_reshard_output_code(self) -> str:
return OUTPUT_RESHARD_TEMPLATE.format()
def generate_return_code(self) -> str:
return self.gene_return_code()
def generate_auto_paralel_branch(self) -> str:
# if there is no tensor input, do not generate the auto parallel branch
if len(self.inputs['names']) == 0:
return ""
return MAIN_DIST_BRANCH_TEMPLATE.format(
self.generate_if_condition_code(),
self.generate_output_creation_code(),
self.generate_infer_spmd_code(),
self.generate_kernel_selection_code(),
self.generate_reshard_input_code(),
self.generate_prepare_data_code(),
self.generate_infer_meta_code(),
self.generate_kernel_call_code(),
self.generate_reshard_output_code(),
self.generate_return_code(),
)
def check_argument_whether_support_auto_parallel(self):
for name in self.inputs['names']:
if self.inputs['input_info'][name] != "const Tensor&":
return False
for out_type in self.outputs['types']:
if out_type != "Tensor":
return False
return True
# override BaseAPI's method
def gene_base_api_code(self, inplace_flag=False):
# init status
self.inplace_flag = inplace_flag
self.dist_output_args = []
self.dense_output_args = []
# generate api body
api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
api_func_name += '_'
if len(self.kernel['func']) > 1:
kernel_dispatch_code = ''
for kernel_name in self.kernel['func']:
kernel_dispatch_code += self.gene_dispatch_code(
kernel_name, inplace_flag
)
return API_IMPL_TEMPLATE.format(
self.get_return_type(inplace_flag),
api_func_name,
self.get_define_args(inplace_flag),
self.gene_kernel_select(),
kernel_dispatch_code
+ DIPATCH_END_GUARD_TEMPLATE.format(self.api),
)
else:
# auto parallel branch; all apis contain this branch by default
# 1. only works for ops that contain a single kernel
# 2. doesn't support initialize ops now
# 3. doesn't support view api
# 4. only for general forward and backward
# 5. only supports single tensor input and output
dist_branch_code = ""
if (
len(self.inputs['names']) > 0
and len(self.view_map) == 0
and self.check_argument_whether_support_auto_parallel()
):
dist_branch_code = self.generate_auto_paralel_branch()
return API_IMPL_TEMPLATE.format(
self.get_return_type(inplace_flag),
api_func_name,
self.get_define_args(inplace_flag),
self.gene_kernel_select(),
dist_branch_code
+ self.gen_kernel_code(
self.kernel['func'][0], '', inplace_flag
),
)
def generate_api(
api_yaml_path, is_fused_ops_yaml, header_file_path, source_file_path
):
apis = []
for each_api_yaml in api_yaml_path:
with open(each_api_yaml, 'r') as f:
api_list = yaml.load(f, Loader=yaml.FullLoader)
if api_list:
apis.extend(api_list)
header_file = open(header_file_path, 'w')
source_file = open(source_file_path, 'w')
namespace = api_namespace()
header_file.write("#pragma once\n")
header_file.write(header_include())
header_file.write(namespace[0])
include_header_file = (
"paddle/phi/api/include/fused_api.h"
if is_fused_ops_yaml is True
else "paddle/phi/api/include/api.h"
)
# not all fused ops support dygraph
if is_fused_ops_yaml is True:
new_apis = [
api
for api in apis
if "support_dygraph_mode" in api
and api["support_dygraph_mode"] is True
]
apis = new_apis
source_file.write(source_include(include_header_file))
source_file.write(namespace[0])
for api in apis:
dist_foward_api = DistForwardAPI(api)
if dist_foward_api.is_dygraph_api:
dist_foward_api.is_dygraph_api = False
header_file.write(dist_foward_api.gene_api_declaration())
if is_fused_ops_yaml is True:
source_file.write(dist_foward_api.gene_api_code())
else:
source_file.write(dist_foward_api.gene_api_code())
header_file.write(namespace[1])
source_file.write(namespace[1])
source_file.write(declare_extension_api())
header_file.close()
source_file.close()
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ API files'
)
parser.add_argument(
'--api_yaml_path',
help='path to api yaml file',
nargs='+',
default=['paddle/phi/api/yaml/ops.yaml'],
)
parser.add_argument(
'--is_fused_ops_yaml',
help='flag of fused ops yaml',
action='store_true',
)
parser.add_argument(
'--api_header_path',
help='output of generated api header code file',
default='paddle/phi/api/include/api.h',
)
parser.add_argument(
'--api_source_path',
help='output of generated api source code file',
default='paddle/phi/api/lib/api.cc',
)
options = parser.parse_args()
api_yaml_path = options.api_yaml_path
is_fused_ops_yaml = options.is_fused_ops_yaml
header_file_path = options.api_header_path
source_file_path = options.api_source_path
generate_api(
api_yaml_path, is_fused_ops_yaml, header_file_path, source_file_path
)
if __name__ == '__main__':
main()
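To make the templates above concrete, below is a rough, hand-assembled sketch of the auto-parallel branch this generator could emit for a hypothetical single-input, single-output op such as relu. It is not verbatim generator output: the infer-meta function (UnchangedInferMeta), the kernel signature, and the variable names are assumptions for illustration, and kernel_backend / kernel_layout / kernel_data_type come from the surrounding kernel-key construction code.

// Auto Parallel condition (illustrative sketch, not literal generated code)
if (AllInputsAreDistTensor(x)) {
  // 1. Create API Output & Prepare Dist and Dense Output
  Tensor api_output;
  auto dist_out = SetKernelDistOutput(&api_output);
  auto dense_out = dist_out->mutable_value();
  // 2. InferSPMD (currently reuses InferMeta on the global tensor)
  phi::MetaTensor meta_dist_out(dist_out);
  phi::UnchangedInferMeta(MakeMetaTensor(*x.impl()), &meta_dist_out);
  // 3. Select Kernel
  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "relu", {kernel_backend, kernel_layout, kernel_data_type});
  const auto& kernel = kernel_result.kernel;
  auto* dev_ctx = GetDeviceContextByBackend(
      kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
  // 4. Reshard Input (not generated yet)
  // 5. PrepareData (DataTransform & Prepare Dist and Dense Input)
  auto dist_input_x = PrepareDataForDistTensor(
      x, GetKernelInputArgDef(kernel.InputAt(0), kernel_backend), {},
      kernel_result.is_stride_kernel);
  auto input_x = dist_input_x->mutable_value();
  // 6. Infer Local DenseTensor Meta
  phi::MetaTensor meta_dense_out(dense_out);
  phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out);
  // 7. DenseTensor Kernel Call
  using kernel_signature = void (*)(
      const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)(*dev_ctx, *input_x, dense_out);
  // 8. Reshard Output (not generated yet)
  // 9. Return
  return api_output;
}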
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import yaml
from backward_api_gen import BackwardAPI
from dist_api_gen import DistForwardAPI
######################
# Code Gen Templates #
######################
# 1. Create API Outputs
SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out = SetKernelDistOutput({});
auto dense_out = dist_out->mutable_value();
"""
INPLACE_OUT_CREATION_TEMPLATE = """
*{} = {};
"""
MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
auto dist_out_{} = SetKernelDistOutput({});
auto dense_out_{} = dist_out_{}->mutable_value();
"""
class DistBackwardAPI(DistForwardAPI, BackwardAPI):
def __init__(self, backward_item_yaml):
BackwardAPI.__init__(self, backward_item_yaml)
self.init_dist_api_members()
# override DistForwardAPI's method
def generate_output_creation_code(self) -> str:
# backward api only needs to generate kernel outputs
output_num = len(self.outputs['types'])
output_creation_code = ""
if output_num == 1:
self.dist_output_args.append('dist_out')
self.dense_output_args.append('dense_out')
if self.outputs['types'][0] == 'Tensor':
output_creation_code += SINGLE_OUT_CREATION_TEMPLATE.format(
self.outputs['names'][0]
)
else:
self.vector_output_size_assertion_check()
elif output_num > 1:
for i, out_type in enumerate(self.outputs['types']):
self.dist_output_args.append(f'dist_out_{i}')
self.dense_output_args.append(f'dense_out_{i}')
if out_type == 'Tensor':
output_creation_code += (
MULTI_SINGLE_OUT_CREATION_TEMPLATE.format(
i, self.outputs['names'][i], i, i
)
)
else:
self.vector_output_size_assertion_check()
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.api
)
)
return output_creation_code
# override DistForwardAPI's method
def generate_return_code(self) -> str:
return "return;"
# override BaseAPI's method
def get_api_func_name(self):
return self.api
# override BaseAPI's method
# The method lookup order is (DistBackwardAPI.__mro__):
# <class '__main__.DistBackwardAPI'>,
# <class 'dist_api_gen.DistForwardAPI'>,
# <class 'api_gen.ForwardAPI'>,
# <class 'backward_api_gen.BackwardAPI'>,
# <class 'api_base.BaseAPI'>,
# <class 'object'>
# if it is not overridden, ForwardAPI's gene_output will be called
def gene_output(
self,
out_dtype_list,
out_tensor_type_list=None,
code_indent='',
inplace_flag=False,
):
return BackwardAPI.gene_output(
self,
out_dtype_list,
out_tensor_type_list,
code_indent,
inplace_flag,
)
# override BaseAPI's method
def get_return_type(self, inplace_flag=False):
return BackwardAPI.get_return_type(self)
# override BaseAPI's method
def gene_return_code(self):
return ""
# override BaseAPI's method
def gene_api_declaration(self) -> str:
return BackwardAPI.gene_api_declaration(self)
def header_include():
return """
#include <tuple>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/utils/optional.h"
"""
def source_include(header_file_path, fw_header_file_path):
return f"""
#include "{header_file_path}"
#include <memory>
#include "glog/logging.h"
#include "gflags/gflags.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "{fw_header_file_path}"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
def backward_api_namespace():
return (
"""
namespace paddle {
namespace experimental {
""",
"""
} // namespace experimental
} // namespace paddle
""",
)
def generate_backward_api(
backward_yaml_path,
is_fused_backward_yaml,
header_file_path,
source_file_path,
):
bw_apis = []
for each_api_yaml in backward_yaml_path:
with open(each_api_yaml, 'r') as f:
api_list = yaml.load(f, Loader=yaml.FullLoader)
if api_list:
bw_apis.extend(api_list)
header_file = open(header_file_path, 'w')
source_file = open(source_file_path, 'w')
namespace = backward_api_namespace()
header_file.write("#pragma once\n")
header_file.write(header_include())
header_file.write(namespace[0])
include_header_file = (
"paddle/phi/api/backward/fused_backward_api.h"
if is_fused_backward_yaml
else "paddle/phi/api/backward/backward_api.h"
)
include_fw_header_file = (
"paddle/phi/api/include/fused_api.h"
if is_fused_backward_yaml
else "paddle/phi/api/include/api.h"
)
source_file.write(
source_include(include_header_file, include_fw_header_file)
)
source_file.write(namespace[0])
# not all fused ops support dygraph
if is_fused_backward_yaml is True:
new_bw_apis = [
bw_api
for bw_api in bw_apis
if "support_dygraph_mode" in bw_api
and bw_api["support_dygraph_mode"] is True
]
bw_apis = new_bw_apis
for bw_api in bw_apis:
dist_bw_api = DistBackwardAPI(bw_api)
header_file.write(dist_bw_api.gene_api_declaration())
if is_fused_backward_yaml is True:
source_file.write(dist_bw_api.gene_api_code())
else:
source_file.write(dist_bw_api.gene_api_code())
header_file.write(namespace[1])
source_file.write(namespace[1])
header_file.close()
source_file.close()
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ backward API files'
)
parser.add_argument(
'--backward_yaml_path',
help='path to backward yaml file',
nargs='+',
default=['paddle/phi/api/yaml/backward.yaml'],
)
parser.add_argument(
'--is_fused_backward_yaml',
help='flag of fused backward yaml',
action='store_true',
)
parser.add_argument(
'--backward_header_path',
help='output of generated backward header code file',
default='paddle/phi/api/backward/backward_api.h',
)
parser.add_argument(
'--backward_source_path',
help='output of generated backward source code file',
default='paddle/phi/api/lib/backward_api.cc',
)
options = parser.parse_args()
backward_yaml_path = options.backward_yaml_path
is_fused_backward_yaml = options.is_fused_backward_yaml
header_file_path = options.backward_header_path
source_file_path = options.backward_source_path
generate_backward_api(
backward_yaml_path,
is_fused_backward_yaml,
header_file_path,
source_file_path,
)
if __name__ == '__main__':
main()
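For comparison, a hedged sketch of the backward counterpart this generator could emit (illustrative names, assembled from SINGLE_OUT_CREATION_TEMPLATE above): the backward API writes into caller-provided grad tensors and returns void rather than constructing api_output.

// hypothetical snippet from a generated backward API (e.g. relu_grad)
if (AllInputsAreDistTensor(out, out_grad)) {
  auto dist_out = SetKernelDistOutput(x_grad);   // x_grad is the Tensor* output parameter
  auto dense_out = dist_out->mutable_value();
  // ... kernel selection, PrepareDataForDistTensor, local InferMeta and the
  // DenseTensor kernel call follow the same pattern as the forward sketch ...
  return;
}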
......@@ -12,6 +12,7 @@ collect_srcs(
flags.cc
errors.cc
enforce.cc
storage_properties.cc
os_info.cc
kernel_context.cc
ddim.cc
......
......@@ -18,6 +18,9 @@
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
class DenseTensorUtils;
namespace distributed {
namespace auto_parallel {
......@@ -54,10 +57,10 @@ class DistTensor final
}
DistTensor(const std::shared_ptr<phi::DenseTensor>& dense_tensor,
const DenseTensorMeta& meta,
const std::shared_ptr<TensorDistAttr>& dist_attr)
: dist_attr_(dist_attr) {
: meta_(meta), dist_attr_(dist_attr) {
value_ = std::make_unique<DenseTensor>(*dense_tensor);
set_meta(dense_tensor->meta());
}
~DistTensor() = default;
......@@ -121,6 +124,8 @@ class DistTensor final
void set_meta(const DenseTensorMeta& meta);
private:
friend class phi::DenseTensorUtils;
DenseTensorMeta meta_;
std::shared_ptr<TensorDistAttr> dist_attr_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
......
......@@ -93,6 +93,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
out_physical_tensor_cur_rank.meta(),
out_dist_attr);
}
......
......@@ -66,7 +66,9 @@ std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_all_gather), out_dist_attr);
std::make_shared<DenseTensor>(out_all_gather),
out_all_gather.meta(),
out_dist_attr);
}
} // namespace distributed
......
......@@ -17,11 +17,15 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace phi {
......@@ -84,6 +88,12 @@ void MetaTensor::set_dims(const DDim& dims) {
} else if (phi::SparseCsrTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
->dims = dims;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->dims = dims;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported setting dims for `%s`.", tensor_->type_info().name()));
......@@ -115,7 +125,12 @@ void MetaTensor::set_dtype(DataType dtype) {
} else if (phi::SparseCsrTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
->dtype = dtype;
// No need to set dtype
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->dtype = dtype;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported settting dtype for `%s`.", tensor_->type_info().name()));
......@@ -146,6 +161,12 @@ void MetaTensor::set_layout(DataLayout layout) {
} else if (phi::SparseCsrTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
->layout = layout;
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor_)) {
DenseTensorUtils::GetMutableMeta(
static_cast<distributed::DistTensor*>(tensor_))
->layout = layout;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported settting layout for `%s`.", tensor_->type_info().name()));
......@@ -156,7 +177,11 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(*this);
ValidCheck(meta_tensor);
if (phi::SparseCooTensor::classof(tensor_) ||
phi::SparseCsrTensor::classof(tensor_)) {
phi::SparseCsrTensor::classof(tensor_)
#ifdef PADDLE_WITH_DISTRIBUTE
|| phi::distributed::DistTensor::classof(tensor_)
#endif
) {
return;
}
if (meta_tensor.lod().empty()) {
......@@ -182,7 +207,11 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
if (phi::DenseTensor::classof(tensor_) ||
phi::SelectedRows::classof(tensor_) ||
phi::SparseCooTensor::classof(tensor_) ||
phi::SparseCsrTensor::classof(tensor_)) {
phi::SparseCsrTensor::classof(tensor_)
#ifdef PADDLE_WITH_DISTRIBUTE
|| phi::distributed::DistTensor::classof(tensor_)
#endif
) {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
......@@ -207,7 +236,12 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
bool is_selected_rows = phi::SelectedRows::classof(tensor_);
bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) {
bool is_dist_tensor = false;
#ifdef PADDLE_WITH_DISTRIBUTE
is_dist_tensor = phi::distributed::DistTensor::classof(tensor_);
#endif
if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr ||
is_dist_tensor) {
if (is_selected_rows) {
const auto in_tensor_base = meta_tensor.tensor();
PADDLE_ENFORCE_EQ(
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/storage_properties.h"
namespace phi {
std::unique_ptr<StorageProperties> CopyStorageProperties(
const std::unique_ptr<StorageProperties>& sp) {
if (sp) {
if (NPUStorageProperties::classof(sp.get())) {
auto result = std::make_unique<NPUStorageProperties>();
result->storage_format =
static_cast<NPUStorageProperties*>(sp.get())->storage_format;
result->storage_dims =
static_cast<NPUStorageProperties*>(sp.get())->storage_dims;
return result;
#ifdef PADDLE_WITH_DNNL
} else if (OneDNNStorageProperties::classof(sp.get())) {
auto result = std::make_unique<OneDNNStorageProperties>();
result->format = static_cast<OneDNNStorageProperties*>(sp.get())->format;
result->mem_desc =
static_cast<OneDNNStorageProperties*>(sp.get())->mem_desc;
return result;
#endif
} else {
return nullptr;
}
}
return nullptr;
}
} // namespace phi
......@@ -28,11 +28,13 @@ namespace phi {
struct StorageProperties {
public:
virtual ~StorageProperties() = default;
TypeInfo<StorageProperties> type_info() const { return type_info_; }
private:
template <typename T, typename U>
friend class TypeInfoTraits;
TypeInfo<StorageProperties> type_info_{
TypeInfo<StorageProperties>::kUnknownType};
};
......@@ -70,29 +72,7 @@ struct OneDNNStorageProperties
};
#endif
static std::unique_ptr<StorageProperties> CopyStorageProperties(
const std::unique_ptr<StorageProperties>& sp) {
if (sp) {
if (NPUStorageProperties::classof(sp.get())) {
auto result = std::make_unique<NPUStorageProperties>();
result->storage_format =
static_cast<NPUStorageProperties*>(sp.get())->storage_format;
result->storage_dims =
static_cast<NPUStorageProperties*>(sp.get())->storage_dims;
return result;
#ifdef PADDLE_WITH_DNNL
} else if (OneDNNStorageProperties::classof(sp.get())) {
auto result = std::make_unique<OneDNNStorageProperties>();
result->format = static_cast<OneDNNStorageProperties*>(sp.get())->format;
result->mem_desc =
static_cast<OneDNNStorageProperties*>(sp.get())->mem_desc;
return result;
#endif
} else {
return nullptr;
}
}
return nullptr;
}
std::unique_ptr<StorageProperties> CopyStorageProperties(
const std::unique_ptr<StorageProperties>& sp);
} // namespace phi
......@@ -21,8 +21,16 @@ limitations under the License. */
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_array.h"
#include "paddle/phi/core/tensor_meta.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
namespace phi {
// TODO(chenweihang): DenseTensorUtils has been abused during the development
// process, and now its semantics are incorrect. It can not only operate
// DenseTensors, but also other types of Tensors, requiring renaming or
// splitting
class DenseTensorUtils {
public:
static DenseTensorMeta* GetMutableMeta(DenseTensor* tensor) {
......@@ -37,6 +45,12 @@ class DenseTensorUtils {
return &(tensor->meta_);
}
#ifdef PADDLE_WITH_DISTRIBUTE
static DenseTensorMeta* GetMutableMeta(distributed::DistTensor* tensor) {
return &(tensor->meta_);
}
#endif
static const std::shared_ptr<phi::Allocation>& GetHolder(
const DenseTensor& tensor) {
return tensor.holder_;
......
......@@ -18,6 +18,7 @@ import numpy as np
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
class TestDistTensor(unittest.TestCase):
......@@ -52,5 +53,36 @@ class TestDistTensor(unittest.TestCase):
self.assertEqual(dist_tensor_with_tensor.dist_attr, dist_attr)
class TestDistTensorForDygraphAPI(unittest.TestCase):
def check_tensor_eq(self, a, b):
np1 = a.numpy()
np2 = b.numpy()
np.testing.assert_allclose(np1, np2, rtol=1e-05)
def create_local_and_dist_tensor_pair(self, np_array):
local_t = paddle.to_tensor(np_array, dtype='float32')
mesh = dist.ProcessMesh([0], dim_names=["x"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
dist_t = dist.shard_tensor(np_array, dist_attr=dist_attr)
local_t.stop_gradient = False
dist_t.stop_gradient = False
return local_t, dist_t
def test_relu_api_for_dist_tensor(self):
x = np.random.random(size=[4, 4]).astype("float32")
local_in, dist_in = self.create_local_and_dist_tensor_pair(x)
local_out = F.relu(local_in)
dist_out = F.relu(dist_in)
self.check_tensor_eq(local_out, dist_out)
# test backward
local_out.backward()
dist_out.backward()
self.check_tensor_eq(local_in.grad, dist_in.grad)
if __name__ == "__main__":
unittest.main()
......@@ -47,7 +47,7 @@ TEST(dist_tensor, constructor) {
EXPECT_TRUE(x3.initialized());
auto a = std::make_shared<DenseTensor>(alloc, DenseTensorMeta(dtype, dims));
DistTensor x4(a, dist_attr);
DistTensor x4(a, a->meta(), dist_attr);
EXPECT_TRUE(x4.defined());
EXPECT_TRUE(x4.initialized());
}
......
......@@ -54,7 +54,9 @@ std::shared_ptr<DistTensor> ConstructReplicatedDistCPU(
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense), dist_attr);
std::make_shared<DenseTensor>(input_dense),
input_dense.meta(),
dist_attr);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -87,7 +89,9 @@ std::shared_ptr<DistTensor> ConstructReplicatedDistGPU(
dist_attr->set_process_mesh(mesh);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(input_dense_gpu), dist_attr);
std::make_shared<DenseTensor>(input_dense_gpu),
input_dense_gpu.meta(),
dist_attr);
}
#endif
......