Unverified commit 14006e96, authored by ronnywang, committed by GitHub

[PHI CAPI] Add support for registering a new operator, PART2 (#55533)

Parent de3e9c30
@@ -38,7 +38,11 @@ limitations under the License. */
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_kernel_registry.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#endif
#include "paddle/phi/api/include/operants_manager.h"
@@ -1226,3 +1230,112 @@ LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
} // namespace framework
} // namespace paddle
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void PD_RegisterOperator(const char* kernel_name_cstr,
size_t in_nargs,
PD_KernelArgumentType* in_args_type,
size_t attr_nargs,
PD_KernelArgumentType* attr_args_type,
size_t out_nargs,
PD_KernelArgumentType* out_args_type,
void (*infer_shape_fn)(PD_InferMetaContext*)) {
std::string kernel_name(kernel_name_cstr);
if (infer_shape_fn &&
!paddle::framework::OpInfoMap::Instance().Has(kernel_name)) {
VLOG(8) << "Registering a new operator: " << kernel_name;
std::vector<std::string> op_inputs, op_outputs, op_attrs;
for (size_t i = 0; i < in_nargs; ++i) {
if (in_args_type[i] == PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) {
op_inputs.push_back("Input_" + std::to_string(i));
} else if (in_args_type[i] ==
PD_KernelArgumentType::PD_ARG_TYPE_LIST_TENSOR) {
op_inputs.push_back("Input_" + std::to_string(i) +
paddle::kTensorVectorSuffix);
} else if (in_args_type[i] ==
PD_KernelArgumentType::PD_ARG_TYPE_OPTIONAL_TENSOR) {
op_inputs.push_back("Input_" + std::to_string(i) +
paddle::kOptionalSuffix);
} else {
op_inputs.push_back("Input_unknown");
}
}
for (size_t i = 0; i < out_nargs; ++i) {
if (out_args_type[i] == PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) {
op_outputs.push_back("Output_" + std::to_string(i));
} else if (out_args_type[i] ==
PD_KernelArgumentType::PD_ARG_TYPE_LIST_TENSOR) {
op_outputs.push_back("Output_" + std::to_string(i) +
paddle::kTensorVectorSuffix);
} else {
op_outputs.push_back("Output_unknown");
}
}
for (size_t i = 0; i < attr_nargs; ++i) {
auto attr_type = attr_args_type[i];
if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_BOOL) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":bool");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_INT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":int");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_FLOAT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":float");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_INT64) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":int64_t");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_STRING) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":std::string");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_INT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":std::vector<int>");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_FLOAT32) {
op_attrs.push_back("Attr_" + std::to_string(i) + ":std::vector<float>");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_INT64) {
op_attrs.push_back("Attr_" + std::to_string(i) +
":std::vector<int64_t>");
} else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_STRING) {
op_attrs.push_back("Attr_" + std::to_string(i) +
":std::vector<std::string>");
} else {
op_attrs.push_back("Attr_unknown");
}
}
paddle::framework::OpInfo info;
// Op
info.creator_ = [](const std::string& op_name,
const paddle::framework::VariableNameMap& inputs,
const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs) {
return new paddle::framework::OperatorWithKernel(
op_name, inputs, outputs, attrs);
};
// OpMaker
info.proto_ = new paddle::framework::proto::OpProto;
info.proto_->set_type(kernel_name);
info.checker_ = new paddle::framework::OpAttrChecker();
paddle::framework::CustomOpMaker custom_maker(
op_inputs, op_outputs, op_attrs);
custom_maker(info.proto_, info.checker_);
PADDLE_ENFORCE_EQ(
info.proto_->IsInitialized(),
true,
phi::errors::PreconditionNotMet(
"Fail to initialize %s's OpProto, because %s is not initialized.",
kernel_name,
info.proto_->InitializationErrorString()));
info.infer_shape_ = [infer_shape_fn, kernel_name](
paddle::framework::InferShapeContext* ctx) {
auto infer_meta_context =
paddle::framework::BuildInferMetaContext(ctx, kernel_name);
infer_shape_fn(
reinterpret_cast<PD_InferMetaContext*>(&infer_meta_context));
};
paddle::framework::OpInfoMap::Instance().Insert(kernel_name, info);
}
}
#endif
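For orientation, here is a minimal plugin-side sketch of calling the new C entry point directly. The operator name `my_custom_add`, the callback `MyAddInferShape`, and the input/attribute/output layout are illustrative assumptions, not part of this patch; only `PD_RegisterOperator` and the `PD_KernelArgumentType` values come from it. As in the guard above, a new operator is created only when an infer-shape callback is supplied and the name is not already present in OpInfoMap.

// Hypothetical plugin code (a sketch, not part of this commit).
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_kernel_registry.h"

// Infer-shape callback defined elsewhere in the plugin (a possible body is
// sketched at the end of this page).
void MyAddInferShape(PD_InferMetaContext* ctx);

void RegisterMyCustomAdd() {
  // Two tensor inputs, one float attribute, one tensor output.
  PD_KernelArgumentType ins[] = {PD_KernelArgumentType::PD_ARG_TYPE_TENSOR,
                                 PD_KernelArgumentType::PD_ARG_TYPE_TENSOR};
  PD_KernelArgumentType attrs[] = {PD_KernelArgumentType::PD_ARG_TYPE_FLOAT32};
  PD_KernelArgumentType outs[] = {PD_KernelArgumentType::PD_ARG_TYPE_TENSOR};
  PD_RegisterOperator("my_custom_add",
                      sizeof(ins) / sizeof(ins[0]), ins,
                      sizeof(attrs) / sizeof(attrs[0]), attrs,
                      sizeof(outs) / sizeof(outs[0]), outs,
                      MyAddInferShape);
}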
@@ -17,10 +17,12 @@
#include "paddle/phi/capi/include/c_data_type.h"
#include "paddle/phi/capi/include/c_device_context.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_int_array.h" #include "paddle/phi/capi/include/c_int_array.h"
#include "paddle/phi/capi/include/c_kernel_context.h" #include "paddle/phi/capi/include/c_kernel_context.h"
#include "paddle/phi/capi/include/c_kernel_factory.h" #include "paddle/phi/capi/include/c_kernel_factory.h"
#include "paddle/phi/capi/include/c_kernel_registry.h" #include "paddle/phi/capi/include/c_kernel_registry.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#include "paddle/phi/capi/include/c_place.h" #include "paddle/phi/capi/include/c_place.h"
#include "paddle/phi/capi/include/c_scalar.h" #include "paddle/phi/capi/include/c_scalar.h"
#include "paddle/phi/capi/include/c_tensor.h" #include "paddle/phi/capi/include/c_tensor.h"
......
...@@ -23,6 +23,8 @@ PD_DECLARE_CAPI(int_array); ...@@ -23,6 +23,8 @@ PD_DECLARE_CAPI(int_array);
PD_DECLARE_CAPI(kernel_context); PD_DECLARE_CAPI(kernel_context);
PD_DECLARE_CAPI(kernel_factory); PD_DECLARE_CAPI(kernel_factory);
PD_DECLARE_CAPI(kernel_registry); PD_DECLARE_CAPI(kernel_registry);
PD_DECLARE_CAPI(infer_meta_context);
PD_DECLARE_CAPI(meta_tensor);
PD_DECLARE_CAPI(place);
PD_DECLARE_CAPI(scalar);
PD_DECLARE_CAPI(tensor);
...
@@ -19,6 +19,7 @@
#include <vector>
#include "paddle/phi/capi/include/c_data_type.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_kernel_context.h" #include "paddle/phi/capi/include/c_kernel_context.h"
#include "paddle/phi/capi/include/c_kernel_factory.h" #include "paddle/phi/capi/include/c_kernel_factory.h"
...@@ -71,6 +72,15 @@ void PD_RegisterPhiKernel(const char *kernel_name_cstr, ...@@ -71,6 +72,15 @@ void PD_RegisterPhiKernel(const char *kernel_name_cstr,
void (*fn)(PD_KernelContext *), void (*fn)(PD_KernelContext *),
void *variadic_kernel_fn); void *variadic_kernel_fn);
void PD_RegisterOperator(const char *kernel_name_cstr,
size_t in_nargs,
PD_KernelArgumentType *in_args_type,
size_t attr_nargs,
PD_KernelArgumentType *attr_args_type,
size_t out_nargs,
PD_KernelArgumentType *out_args_type,
void (*infer_shape_fn)(PD_InferMetaContext *));
#ifdef __cplusplus
} // extern "C"
#endif
...
@@ -187,6 +187,30 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiOutputAt(
  return ret;
}
inline std::vector<phi::capi::MetaTensor> PD_InferMetaMultiInputAt(
PD_InferMetaContext *ctx, size_t index) {
std::vector<phi::capi::MetaTensor> ret;
auto list = PD_InferMetaContextMultiInputAt(ctx, index);
auto data = reinterpret_cast<PD_MetaTensor **>(list.data);
for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]);
}
PD_DeletePointerList(list);
return ret;
}
inline std::vector<phi::capi::MetaTensor> PD_InferMetaMultiOutputAt(
PD_InferMetaContext *ctx, size_t index) {
std::vector<phi::capi::MetaTensor> ret;
auto list = PD_InferMetaContextMultiOutputAt(ctx, index);
auto data = reinterpret_cast<PD_MetaTensor **>(list.data);
for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]);
}
PD_DeletePointerList(list);
return ret;
}
template <typename T>
inline std::vector<T *> PD_GetPointerVector(std::vector<T> *vec) {
  std::vector<T *> ret;
@@ -336,6 +360,152 @@ inline std::vector<bool> PD_AttrAt<std::vector<bool>>(PD_KernelContext *ctx,
  return list;
}
template <typename T>
inline T PD_InferMetaAttrAt(PD_InferMetaContext *ctx, size_t index);
template <>
inline bool PD_InferMetaAttrAt<bool>(PD_InferMetaContext *ctx, size_t index) {
return PD_InferMetaContextBoolAttrAt(ctx, index);
}
template <>
inline int32_t PD_InferMetaAttrAt<int32_t>(PD_InferMetaContext *ctx,
size_t index) {
return PD_InferMetaContextInt32AttrAt(ctx, index);
}
template <>
inline int64_t PD_InferMetaAttrAt<int64_t>(PD_InferMetaContext *ctx,
size_t index) {
return PD_InferMetaContextInt64AttrAt(ctx, index);
}
template <>
inline float PD_InferMetaAttrAt<float>(PD_InferMetaContext *ctx, size_t index) {
return PD_InferMetaContextFloatAttrAt(ctx, index);
}
template <>
inline double PD_InferMetaAttrAt<double>(PD_InferMetaContext *ctx,
size_t index) {
return PD_InferMetaContextDoubleAttrAt(ctx, index);
}
template <>
inline std::string PD_InferMetaAttrAt<std::string>(PD_InferMetaContext *ctx,
size_t index) {
return PD_InferMetaContextStringAttrAt(ctx, index);
}
template <>
inline PD_DataType PD_InferMetaAttrAt<PD_DataType>(PD_InferMetaContext *ctx,
size_t index) {
return PD_InferMetaContextDataTypeAttrAt(ctx, index);
}
template <>
inline PD_DataLayout PD_InferMetaAttrAt<PD_DataLayout>(PD_InferMetaContext *ctx,
size_t index) {
return PD_InferMetaContextDataLayoutAttrAt(ctx, index);
}
template <>
inline std::vector<int32_t> PD_InferMetaAttrAt<std::vector<int32_t>>(
PD_InferMetaContext *ctx, size_t index) {
auto list = PD_InferMetaContextListInt32AttrAt(ctx, index);
auto data = reinterpret_cast<int32_t *>(list.data);
std::vector<int32_t> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline std::vector<int64_t> PD_InferMetaAttrAt<std::vector<int64_t>>(
PD_InferMetaContext *ctx, size_t index) {
auto list = PD_InferMetaContextListInt64AttrAt(ctx, index);
auto data = reinterpret_cast<int64_t *>(list.data);
std::vector<int64_t> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline std::vector<float> PD_InferMetaAttrAt<std::vector<float>>(
PD_InferMetaContext *ctx, size_t index) {
auto list = PD_InferMetaContextListFloatAttrAt(ctx, index);
auto data = reinterpret_cast<float *>(list.data);
std::vector<float> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline std::vector<double> PD_InferMetaAttrAt<std::vector<double>>(
PD_InferMetaContext *ctx, size_t index) {
auto list = PD_InferMetaContextListDoubleAttrAt(ctx, index);
auto data = reinterpret_cast<double *>(list.data);
std::vector<double> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline phi::capi::Scalar PD_InferMetaAttrAt<phi::capi::Scalar>(
PD_InferMetaContext *ctx, size_t index) {
auto scalar = PD_InferMetaContextScalarAttrAt(ctx, index);
return phi::capi::Scalar(scalar);
}
template <>
inline phi::capi::IntArray PD_InferMetaAttrAt<phi::capi::IntArray>(
PD_InferMetaContext *ctx, size_t index) {
auto int_array = PD_InferMetaContextIntArrayAttrAt(ctx, index);
return phi::capi::IntArray(int_array);
}
template <>
inline phi::capi::Place PD_InferMetaAttrAt<phi::capi::Place>(
PD_InferMetaContext *ctx, size_t index) {
auto place = PD_InferMetaContextPlaceAttrAt(ctx, index);
return phi::capi::Place(place);
}
template <>
inline std::vector<phi::capi::Scalar>
PD_InferMetaAttrAt<std::vector<phi::capi::Scalar>>(PD_InferMetaContext *ctx,
size_t index) {
auto c_list = PD_InferMetaContextListScalarAttrAt(ctx, index);
auto data = reinterpret_cast<PD_Scalar **>(c_list.data);
std::vector<phi::capi::Scalar> list;
for (size_t i = 0; i < c_list.size; ++i) {
list.emplace_back(data[i]);
}
PD_DeletePointerList(c_list);
return list;
}
template <>
inline std::vector<std::string> PD_InferMetaAttrAt<std::vector<std::string>>(
PD_InferMetaContext *ctx, size_t index) {
auto c_list = PD_InferMetaContextListStringAttrAt(ctx, index);
auto data = reinterpret_cast<char **>(c_list.data);
std::vector<std::string> list;
for (size_t i = 0; i < c_list.size; ++i) {
list.emplace_back(data[i]);
}
PD_DeletePointerList(c_list);
return list;
}
template <>
inline std::vector<bool> PD_InferMetaAttrAt<std::vector<bool>>(
PD_InferMetaContext *ctx, size_t index) {
auto c_list = PD_InferMetaContextListBoolAttrAt(ctx, index);
  std::vector<bool> list(c_list.size);
auto data = reinterpret_cast<uint8_t *>(c_list.data);
for (size_t i = 0; i < c_list.size; ++i) {
list[i] = static_cast<bool>(data[i]);
}
PD_DeleteUInt8List(c_list);
return list;
}
#define CPP_TYPE_TO_PD_ARG_TYPE_REGISTER(_) \
  _(phi::capi::DenseTensor, ::PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) \
  _(phi::capi::DeviceContext, ::PD_KernelArgumentType::PD_ARG_TYPE_CONTEXT) \
@@ -391,13 +561,82 @@ using IntArray = capi::IntArray;
using Place = capi::Place;
using DataType = ::PD_DataType;
using DataLayout = ::PD_DataLayout;
using DenseTensor = capi::DenseTensor;
using MetaTensor = capi::MetaTensor;
} // namespace phi
#include "paddle/phi/capi/include/kernel_utils.h"
// clang-format off
#define PD_BUILD_NEW_PHI_KERNEL(kernel_name, \
backend, \
layout, \
meta_kernel_fn, \
infer_shape_fn, \
...) \
static void \
__CUSTOM_adefs_CFN_##kernel_name##_##backend##_##layout( \
const PD_KernelKey* kernel_key, PD_Kernel* kernel); \
template <typename kernel_type> \
struct __##kernel_name##_##backend##_##layout##__ { \
__##kernel_name##_##backend##_##layout##__() { \
::phi::capi::CustomKernelArgsParseFunctor<decltype( \
&meta_kernel_fn<kernel_type>)> \
parser; \
PD_RegisterOperator(#kernel_name, \
parser.in_args_type.size(), \
parser.in_args_type.data(), \
parser.attr_args_type.size(), \
parser.attr_args_type.data(), \
parser.out_args_type.size(), \
parser.out_args_type.data(), \
PHI_CAPI_INFER_META(infer_shape_fn)); \
PD_RegisterPhiKernel( \
#kernel_name, \
#backend, \
::phi::capi::CppTypeToPDType<kernel_type>::Type(), \
PD_DATALAYOUT(layout), \
parser.in_args_type.size(), \
parser.in_args_type.data(), \
parser.attr_args_type.size(), \
parser.attr_args_type.data(), \
parser.out_args_type.size(), \
parser.out_args_type.data(), \
__CUSTOM_adefs_CFN_##kernel_name##_##backend##_##layout, \
CUSTOM_PHI_KERNEL(meta_kernel_fn<kernel_type>), \
CUSTOM_PHI_VARIADIC_KERNEL( \
meta_kernel_fn<kernel_type>)); \
} \
static void Touch() {} \
}; \
PD_CUSTOM_PHI_KERNEL_STATIC_ASSERT_GLOBAL_NAMESPACE( \
CUSTOM_tp_ns_check_##kernel_name##_##backend##_##layout, \
"PD_BUILD_KERNEL must be called in global namespace."); \
static void \
__CUSTOM_adefs_FN_##kernel_name##_##backend##_##layout( \
const ::phi::capi::KernelKey &kernel_key, \
::phi::capi::Kernel* kernel); \
_PD_BUILD_PHI_KERNEL(__##kernel_name##_##backend##_##layout##__, \
kernel_name, \
backend, \
layout, \
meta_kernel_fn, \
__VA_ARGS__) \
void \
__CUSTOM_adefs_CFN_##kernel_name##_##backend##_##layout( \
const PD_KernelKey* kernel_key, PD_Kernel* kernel) { \
auto cc_kernel = ::phi::capi::Kernel(kernel); \
__CUSTOM_adefs_FN_##kernel_name##_##backend##_##layout( \
::phi::capi::KernelKey( \
const_cast<PD_KernelKey*>(kernel_key)), \
&cc_kernel); \
} \
void \
__CUSTOM_adefs_FN_##kernel_name##_##backend##_##layout( \
const ::phi::capi::KernelKey &kernel_key, \
::phi::capi::Kernel* kernel)
#define PD_BUILD_PHI_KERNEL(kernel_name, \
                            backend, \
                            layout, \
...
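To show how the pieces fit together, here is a hedged usage sketch of the new PD_BUILD_NEW_PHI_KERNEL macro. Every name below (my_add, my_device, custom_kernel::AddKernel, custom_kernel::AddInferMeta) is hypothetical, the include path is assumed to be the header that exposes the macro, and the trailing data-type list is assumed to follow the same convention as PD_BUILD_PHI_KERNEL. Compared with PD_BUILD_PHI_KERNEL, the new macro additionally calls PD_RegisterOperator with PHI_CAPI_INFER_META(infer_shape_fn), so the typed infer-meta function also drives the operator's infer_shape_.

// Hypothetical plugin kernel registration (a sketch, not part of this commit).
#include "paddle/phi/capi/include/kernel_registry.h"  // assumed location of PD_BUILD_NEW_PHI_KERNEL

namespace custom_kernel {

// Device kernel; the body is omitted here.
template <typename T>
void AddKernel(const phi::capi::DeviceContext& dev_ctx,
               const phi::capi::DenseTensor& x,
               const phi::capi::DenseTensor& y,
               phi::capi::DenseTensor* out);

// Typed infer-meta function, dispatched through InferMetaFnImpl (defined in
// kernel_utils.h below).
void AddInferMeta(const phi::capi::MetaTensor& x,
                  const phi::capi::MetaTensor& y,
                  phi::capi::MetaTensor* out);

}  // namespace custom_kernel

// Registers the phi kernel for float/double and creates the "my_add" operator.
PD_BUILD_NEW_PHI_KERNEL(my_add,
                        my_device,
                        ALL_LAYOUT,
                        custom_kernel::AddKernel,
                        custom_kernel::AddInferMeta,
                        float,
                        double) {}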
@@ -24,6 +24,9 @@ namespace capi {
#define CUSTOM_PHI_KERNEL(...) \
  ::phi::capi::CustomKernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute
#define PHI_CAPI_INFER_META(...) \
::phi::capi::InferMetaFnImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Call
#define CUSTOM_PHI_VARIADIC_KERNEL(...) \
  reinterpret_cast<void *>( \
      &::phi::capi::CustomKernelImpl<decltype(&__VA_ARGS__), \
@@ -909,6 +912,151 @@ struct CustomKernelArgsParseFunctor<Return_ (*)(Args_...)> {
  }
};
#define PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<attr_type, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
attr_type arg = PD_InferMetaAttrAt<attr_type>(ctx, attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}
#define PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( \
attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<const attr_type &, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
attr_type arg = PD_InferMetaAttrAt<attr_type>(ctx, attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}
template <typename T>
struct InferMetaTypeTag {};
template <typename Fn, Fn fn>
struct InferMetaFnImpl;
template <typename Return, typename... Args, Return (*infer_meta_fn)(Args...)>
struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
static void Call(PD_InferMetaContext *ctx) {
InferMetaFnCallHelper<Args...,
InferMetaTypeTag<int>>::template Call<0, 0, 0>(ctx);
}
private:
template <typename... RemainingArgs>
struct InferMetaFnCallHelper;
template <typename... Tail>
struct InferMetaFnCallHelper<const MetaTensor &, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) {
static_assert(attr_idx == 0,
"InferMeta's Input should appear before Attributes.");
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
auto arg = MetaTensor(PD_InferMetaContextInputAt(ctx, in_idx));
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<const MetaTensor *> &,
Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) {
static_assert(attr_idx == 0,
"InferMeta's Input should appear before Attributes.");
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
auto arg = PD_InferMetaMultiInputAt(ctx, in_idx);
std::vector<const MetaTensor *> tensor_ptr_vec;
for (auto &tensor : arg) {
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr);
}
InferMetaFnCallHelper<Tail...>::
template Call<in_idx + 1, attr_idx, out_idx>(
ctx, pargs..., tensor_ptr_vec);
}
};
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
// PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<bool>);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int>);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<float>);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<double>);
PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);
template <typename... Tail>
struct InferMetaFnCallHelper<MetaTensor *, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) {
auto arg = MetaTensor(PD_InferMetaContextOutputAt(ctx, out_idx));
auto *arg_ptr = &arg;
InferMetaFnCallHelper<
Tail...>::template Call<in_idx, attr_idx, out_idx + 1>(ctx,
pargs...,
arg_ptr);
}
};
template <typename... Tail>
struct InferMetaFnCallHelper<std::vector<MetaTensor *>, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) {
auto arg = PD_InferMetaMultiOutputAt(ctx, out_idx);
std::vector<MetaTensor *> tensor_ptr_vec;
for (auto &tensor : arg) {
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr);
}
InferMetaFnCallHelper<Tail...>::
template Call<in_idx, attr_idx, out_idx + 1>(
ctx, pargs..., tensor_ptr_vec);
}
};
/* End case */
template <typename T>
struct InferMetaFnCallHelper<InferMetaTypeTag<T>> {
template <int in_idx, int attr_idx, int out_idx>
static void Call(PD_InferMetaContext *ctx, Args &...args) {
return infer_meta_fn(args...);
}
};
};
} // namespace capi
} // namespace phi
...
@@ -7,6 +7,8 @@ collect_srcs(
  c_kernel_context.cc
  c_kernel_factory.cc
  c_kernel_registry.cc
c_infer_meta_context.cc
c_meta_tensor.cc
  c_place.cc
  c_scalar.cc
  c_tensor.cc)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/common.h"
#include "paddle/phi/capi/include/type_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
PD_MetaTensor* PD_InferMetaContextInputAt(PD_InferMetaContext* ctx,
size_t index) {
auto* meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const std::pair<int, int> range = meta_ctx->InputRangeAt(index);
const phi::MetaTensor& arg = meta_ctx->InputAt(range.first);
return reinterpret_cast<PD_MetaTensor*>(const_cast<phi::MetaTensor*>(&arg));
}
PD_List PD_InferMetaContextMultiInputAt(PD_InferMetaContext* ctx,
size_t index) {
auto* meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const std::pair<int, int> range = meta_ctx->InputRangeAt(index);
std::vector<const phi::MetaTensor*> tensor_vec =
meta_ctx->InputsBetween(range.first, range.second);
PD_List list;
list.size = tensor_vec.size();
list.data = new void*[list.size];
for (size_t i = 0; i < list.size; ++i) {
(reinterpret_cast<void**>(list.data))[i] =
reinterpret_cast<void*>(const_cast<phi::MetaTensor*>(tensor_vec[i]));
}
return list;
}
PD_MetaTensor* PD_InferMetaContextOutputAt(PD_InferMetaContext* ctx,
size_t index) {
auto* meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const std::pair<int, int> range = meta_ctx->OutputRangeAt(index);
phi::MetaTensor* arg = meta_ctx->MutableOutputAt(range.first);
return reinterpret_cast<PD_MetaTensor*>(arg);
}
PD_List PD_InferMetaContextMultiOutputAt(PD_InferMetaContext* ctx,
size_t index) {
auto* meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const std::pair<int, int> range = meta_ctx->OutputRangeAt(index);
std::vector<phi::MetaTensor*> tensor_vec =
meta_ctx->MutableOutputBetween(range.first, range.second);
PD_List list;
list.size = tensor_vec.size();
list.data = new void*[list.size];
for (size_t i = 0; i < list.size; ++i) {
(reinterpret_cast<void**>(list.data))[i] =
reinterpret_cast<void*>(tensor_vec[i]);
}
return list;
}
bool PD_InferMetaContextBoolAttrAt(PD_InferMetaContext* ctx, size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return meta_ctx->AttrAt<bool>(index);
}
int32_t PD_InferMetaContextInt32AttrAt(PD_InferMetaContext* ctx, size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return meta_ctx->AttrAt<int32_t>(index);
}
int64_t PD_InferMetaContextInt64AttrAt(PD_InferMetaContext* ctx, size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return meta_ctx->AttrAt<int64_t>(index);
}
float PD_InferMetaContextFloatAttrAt(PD_InferMetaContext* ctx, size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return meta_ctx->AttrAt<float>(index);
}
double PD_InferMetaContextDoubleAttrAt(PD_InferMetaContext* ctx, size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return meta_ctx->AttrAt<double>(index);
}
PD_Scalar* PD_InferMetaContextScalarAttrAt(PD_InferMetaContext* ctx,
size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return reinterpret_cast<PD_Scalar*>(
const_cast<phi::Scalar*>(&meta_ctx->AttrAt<phi::Scalar>(index)));
}
PD_IntArray* PD_InferMetaContextIntArrayAttrAt(PD_InferMetaContext* ctx,
size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return reinterpret_cast<PD_IntArray*>(
const_cast<phi::IntArray*>(&meta_ctx->AttrAt<phi::IntArray>(index)));
}
PD_List PD_InferMetaContextListBoolAttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<bool>>(index);
list.size = cc_list.size();
auto data = reinterpret_cast<uint8_t*>(new uint8_t[cc_list.size()]);
for (size_t i = 0; i < cc_list.size(); ++i) {
data[i] = static_cast<uint8_t>(cc_list[i]);
}
list.data = data;
return list;
}
PD_List PD_InferMetaContextListInt32AttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<int32_t>>(index);
list.size = cc_list.size();
list.data = const_cast<int32_t*>(cc_list.data());
return list;
}
PD_List PD_InferMetaContextListInt64AttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<int64_t>>(index);
list.size = cc_list.size();
list.data = const_cast<int64_t*>(cc_list.data());
return list;
}
PD_List PD_InferMetaContextListFloatAttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<float>>(index);
list.size = cc_list.size();
list.data = const_cast<float*>(cc_list.data());
return list;
}
PD_List PD_InferMetaContextListDoubleAttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<double>>(index);
list.size = cc_list.size();
list.data = const_cast<double*>(cc_list.data());
return list;
}
char* PD_InferMetaContextStringAttrAt(PD_InferMetaContext* ctx, size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return const_cast<char*>(meta_ctx->AttrAt<std::string>(index).data());
}
PD_List PD_InferMetaContextListStringAttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<std::string>>(index);
list.size = cc_list.size();
auto data = new char*[list.size];
for (size_t i = 0; i < list.size; ++i) {
data[i] = const_cast<char*>(cc_list[i].data());
}
list.data = reinterpret_cast<void*>(data);
return list;
}
PD_List PD_InferMetaContextListScalarAttrAt(PD_InferMetaContext* ctx,
size_t index) {
PD_List list;
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
const auto& cc_list = meta_ctx->AttrAt<std::vector<phi::Scalar>>(index);
list.size = cc_list.size();
auto data = new PD_Scalar*[list.size];
for (size_t i = 0; i < list.size; ++i) {
data[i] =
const_cast<PD_Scalar*>(reinterpret_cast<const PD_Scalar*>(&cc_list[i]));
}
list.data = data;
return list;
}
PD_Place* PD_InferMetaContextPlaceAttrAt(PD_InferMetaContext* ctx,
size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return reinterpret_cast<PD_Place*>(
const_cast<phi::Place*>(&meta_ctx->AttrAt<phi::Place>(index)));
}
PD_DataType PD_InferMetaContextDataTypeAttrAt(PD_InferMetaContext* ctx,
size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return phi::capi::ToPDDataType(meta_ctx->AttrAt<phi::DataType>(index));
}
PD_DataLayout PD_InferMetaContextDataLayoutAttrAt(PD_InferMetaContext* ctx,
size_t index) {
auto meta_ctx = reinterpret_cast<phi::InferMetaContext*>(ctx);
return phi::capi::ToPDDataLayout(meta_ctx->AttrAt<phi::DataLayout>(index));
}
PD_REGISTER_CAPI(infer_meta_context);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/capi/include/c_meta_tensor.h"
#include "paddle/phi/capi/include/common.h"
#include "paddle/phi/capi/include/type_utils.h"
#include "paddle/phi/core/meta_tensor.h"
PD_DataType PD_MetaTensorGetPDDataType(const PD_MetaTensor *tensor,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return PD_DataType::UNDEFINED;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
return phi::capi::ToPDDataType(cc_tensor->dtype());
}
PD_DataLayout PD_MetaTensorGetDataLayout(const PD_MetaTensor *tensor,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return PD_DataLayout::ALL_LAYOUT;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
return phi::capi::ToPDDataLayout(cc_tensor->layout());
}
int64_t PD_MetaTensorGetElementCount(const PD_MetaTensor *tensor,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
return cc_tensor->numel();
}
int64_t PD_MetaTensorGetNumDims(const PD_MetaTensor *tensor,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
return cc_tensor->dims().size();
}
int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
size_t index,
PD_Status *status) {
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
if (status) {
if (!tensor || index >= static_cast<size_t>(cc_tensor->dims().size())) {
*status = C_FAILED;
return 0;
}
*status = C_SUCCESS;
}
return cc_tensor->dims()[index];
}
bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return false;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<const phi::MetaTensor *>(tensor);
return cc_tensor->initialized();
}
void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
int64_t ndims,
const int64_t *dims,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<phi::MetaTensor *>(tensor);
std::vector<int> shape(dims, dims + ndims);
cc_tensor->set_dims(phi::make_ddim(shape));
}
void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
PD_DataType dtype,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<phi::MetaTensor *>(tensor);
cc_tensor->set_dtype(phi::capi::ToPhiDataType(dtype));
}
void PD_MetaTensorSetDataLayout(PD_MetaTensor *tensor,
PD_DataLayout layout,
PD_Status *status) {
if (status) {
if (!tensor) {
*status = C_FAILED;
return;
}
*status = C_SUCCESS;
}
auto cc_tensor = reinterpret_cast<phi::MetaTensor *>(tensor);
cc_tensor->set_layout(phi::capi::ToPhiDataLayout(layout));
}
PD_REGISTER_CAPI(meta_tensor);
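Finally, a minimal sketch of what an infer-shape callback registered through PD_RegisterOperator might look like when written directly against these C helpers. The operator layout (one input, one output) and the function name are assumptions for illustration; only the PD_* functions come from this patch.

// Hypothetical callback body (a sketch, not part of this commit).
#include <vector>

#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"

void MyAddInferShape(PD_InferMetaContext* ctx) {
  PD_MetaTensor* x = PD_InferMetaContextInputAt(ctx, 0);
  PD_MetaTensor* out = PD_InferMetaContextOutputAt(ctx, 0);

  // Copy the input shape to the output.
  int64_t ndims = PD_MetaTensorGetNumDims(x, /*status=*/nullptr);
  std::vector<int64_t> dims(ndims);
  for (int64_t i = 0; i < ndims; ++i) {
    dims[i] = PD_MetaTensorGetDim(x, i, /*status=*/nullptr);
  }
  PD_MetaTensorSetDims(out, ndims, dims.data(), /*status=*/nullptr);

  // The output keeps the input's data type.
  PD_MetaTensorSetDataType(out,
                           PD_MetaTensorGetPDDataType(x, /*status=*/nullptr),
                           /*status=*/nullptr);
}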