Unverified commit b199ba85, authored by Chen Weihang, committed by GitHub

[PTen] Refine Kernel Registrar Writing (#37977)

* refine the kernel register impl

* fix cmake and symbol error

* remove overload macro

* polish details
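
In practice the change shows up in two ways at every registration site: the kernel name becomes a plain identifier instead of a string literal (dotted compatibility names such as "reshape.host" become reshape_host), and the per-file PT_REGISTER_MODULE / PT_DECLARE_MODULE pair is replaced by a per-kernel symbol that PT_DECLARE_KERNEL can reference. A condensed before/after sketch, pieced together from the hunks below:

```cpp
// Before: string-literal kernel name plus a module symbol per file.
PT_REGISTER_MODULE(MathCPU);
PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {}

// After: identifier kernel name; the registrar itself emits
// TouchKernelSymbolFor_sign_CPU(), so no module macro is needed.
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}

// Consumers keep the kernel's object file alive with the new declare macro:
PT_DECLARE_KERNEL(sign, CPU);
```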
Parent: dfed4a63
@@ -555,10 +555,10 @@ class Reshape2Op : public ReshapeOp {
       const framework::ExecutionContext &ctx) const override {
     auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
     if (multi_inputs.size() > 0) {
-      return framework::KernelSignature("reshape.mulhost", {"X", "ShapeTensor"},
+      return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"},
                                         {}, {"Out"});
     } else if (ctx.HasInput("Shape")) {
-      return framework::KernelSignature("reshape.host", {"X", "Shape"}, {},
+      return framework::KernelSignature("reshape_host", {"X", "Shape"}, {},
                                         {"Out"});
     } else {
       return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/kernel_registry.h"
// TODO(chenweihang) After the kernel is split into a single file,
// the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
#endif
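
This new header centralizes the per-kernel declarations, so code that previously listed PT_DECLARE_MODULE(...) blocks (for example the generated API sources in the Python codegen hunk near the end of this diff) only needs a single include. A minimal sketch of the consuming side:

```cpp
// One include brings in the PT_DECLARE_KERNEL(...) calls above.
#include "paddle/pten/api/lib/kernel_declare.h"
```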
@@ -25,10 +25,14 @@ limitations under the License. */
 #include "paddle/pten/include/core.h"
 #include "paddle/pten/include/infermeta.h"
-PT_DECLARE_MODULE(UtilsCPU);
+PT_DECLARE_KERNEL(copy, CPU);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PT_DECLARE_MODULE(UtilsCUDA);
+PT_DECLARE_KERNEL(copy, CUDA);
+#endif
+#ifdef PADDLE_WITH_XPU
+PT_DECLARE_KERNEL(copy, XPU);
 #endif
 namespace paddle {
......
@@ -15,6 +15,7 @@
 #pragma once
 #include <cstring>
+#include <string>
 #include <type_traits>
 #include <typeindex>
 #include <typeinfo>
@@ -24,6 +25,8 @@
 #include "paddle/pten/core/kernel_factory.h"
 #include "paddle/pten/core/kernel_utils.h"
+#include "paddle/fluid/platform/enforce.h"
 namespace pten {
 #define BACKEND(arg__) pten::Backend::arg__
@@ -193,52 +196,58 @@ struct KernelRegistrar {
 #define _PT_ARG_N(args) _PT_ARG_N_EXPAND args
 #define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
+/** PT_REGISTER_KERNEL
+ *
+ * The most frequently used kernel registration macro, used for kernel
+ * registration with only data type as template parameter, and the function
+ * pointer of the corresponding data type is automatically instantiated
+ * during registration.
+ */
 #define PT_REGISTER_KERNEL(                                                  \
     kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...)           \
-  _PT_REGISTER_KERNEL(kernel_name,                                           \
-                      PT_ID,                                                 \
-                      backend,                                               \
-                      layout,                                                \
-                      meta_kernel_fn,                                        \
-                      cpp_dtype,                                             \
-                      __VA_ARGS__)
-#ifndef _WIN32
-#define _PT_REGISTER_KERNEL(                                                 \
-    kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...)  \
   PT_STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
-      PT_CONCATENATE(pt_op_kernel_ns_check_, func_id),                       \
+      pt_register_kernel_ns_check_##kernel_name,                             \
       "PT_REGISTER_KERNEL must be called in global namespace.");             \
+  _PT_REGISTER_KERNEL(                                                       \
+      kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
+#ifndef _WIN32
+#define _PT_REGISTER_KERNEL(                                                 \
+    kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...)           \
   PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__);           \
-  static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_,                       \
-                             func_id)(::pten::Kernel*);                      \
+  static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*);        \
   PT_KERNEL_REGISTRAR_INIT(kernel_name,                                      \
-                           func_id,                                          \
                            backend,                                          \
                            layout,                                           \
-                           &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
+                           &__PT_KERNEL_args_def_FN_##kernel_name,           \
                            meta_kernel_fn,                                   \
                            cpp_dtype,                                        \
                            __VA_ARGS__);                                     \
-  void PT_CONCATENATE(__PT_KERNEL_args_def_FN_,                              \
-                      func_id)(::pten::Kernel * kernel)
+  void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
 #else
+/**
+ * `template decltype(fn) fn` can work on gcc and clang,
+ * but msvc will failed, error like:
+ *
+ *   error C2206: typedef cannot be used for function definition
+ *
+ * reference:
+ *
+ *   https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua
+ *
+ * And msvc can work without template instantiation
+ */
 #define _PT_REGISTER_KERNEL(                                                 \
-    kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...)  \
-  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
-      PT_CONCATENATE(pt_op_kernel_ns_check_, func_id),                       \
-      "PT_REGISTER_KERNEL must be called in global namespace.");             \
-  static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_,                       \
-                             func_id)(::pten::Kernel*);                      \
+    kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...)           \
+  static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*);        \
   PT_KERNEL_REGISTRAR_INIT(kernel_name,                                      \
-                           func_id,                                          \
                            backend,                                          \
                            layout,                                           \
-                           &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
+                           &__PT_KERNEL_args_def_FN_##kernel_name,           \
                            meta_kernel_fn,                                   \
                            cpp_dtype,                                        \
                            __VA_ARGS__);                                     \
-  void PT_CONCATENATE(__PT_KERNEL_args_def_FN_,                              \
-                      func_id)(::pten::Kernel * kernel)
+  void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
 #endif
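
For reference, here is a rough hand expansion of the non-Windows branch for a single data type, following the macro bodies above. The registrar_id suffix and the user-supplied args-def body are simplified, so this is illustrative rather than something to compile on its own:

```cpp
// PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float) { ... } roughly becomes:
template decltype(pten::Sign<float>) pten::Sign<float>;        // explicit instantiation
static void __PT_KERNEL_args_def_FN_sign(::pten::Kernel*);     // forward declaration
static const ::pten::KernelRegistrar __reg_pt_kernel_sign_0(   // one registrar per dtype
    "sign",                                                    // #kernel_name stringized
    BACKEND(CPU),
    DATALAYOUT(ANY),
    ::paddle::experimental::CppTypeToDataType<float>::Type(),
    ::pten::KernelArgsParseFunctor<decltype(&pten::Sign<float>)>::Parse,
    &__PT_KERNEL_args_def_FN_sign,
    PT_KERNEL(pten::Sign<float>));
int TouchKernelSymbolFor_sign_CPU() { return 0; }              // symbol for PT_DECLARE_KERNEL
void __PT_KERNEL_args_def_FN_sign(::pten::Kernel* kernel) { /* user-provided body */ }
```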
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
...@@ -251,19 +260,6 @@ struct KernelRegistrar { ...@@ -251,19 +260,6 @@ struct KernelRegistrar {
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, cpp_dtype, __VA_ARGS__) (meta_kernel_fn, cpp_dtype, __VA_ARGS__)
/**
* `template decltype(fn) fn` can work on gcc and clang,
* but msvc will failed, error like:
*
* error C2206: typedef cannot be used for function definition
*
* reference:
*
* https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua
*
* So we solve the explict instantiation of kernel by CMake
*/
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype> template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \ #define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \
...@@ -309,17 +305,10 @@ struct KernelRegistrar { ...@@ -309,17 +305,10 @@ struct KernelRegistrar {
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \ template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT(kernel_name, \ #define PT_KERNEL_REGISTRAR_INIT( \
func_id, \ kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \ _PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \ kernel_name, \
func_id, \
backend, \ backend, \
layout, \ layout, \
args_def_fn, \ args_def_fn, \
...@@ -333,7 +322,6 @@ struct KernelRegistrar { ...@@ -333,7 +322,6 @@ struct KernelRegistrar {
and multi-line macros cannot be skipped with NOLINT.*/ and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N, \ #define _PT_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \ kernel_name, \
func_id, \
backend, \ backend, \
layout, \ layout, \
args_def_fn, \ args_def_fn, \
...@@ -342,7 +330,6 @@ struct KernelRegistrar { ...@@ -342,7 +330,6 @@ struct KernelRegistrar {
...) \ ...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \ kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -354,7 +341,6 @@ struct KernelRegistrar { ...@@ -354,7 +341,6 @@ struct KernelRegistrar {
// clang-format on // clang-format on
#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -363,17 +349,17 @@ struct KernelRegistrar { ...@@ -363,17 +349,17 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -382,8 +368,8 @@ struct KernelRegistrar { ...@@ -382,8 +368,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -392,7 +378,6 @@ struct KernelRegistrar { ...@@ -392,7 +378,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -400,7 +385,6 @@ struct KernelRegistrar { ...@@ -400,7 +385,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -409,8 +393,8 @@ struct KernelRegistrar { ...@@ -409,8 +393,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -419,7 +403,6 @@ struct KernelRegistrar { ...@@ -419,7 +403,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -427,7 +410,6 @@ struct KernelRegistrar { ...@@ -427,7 +410,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -436,8 +418,8 @@ struct KernelRegistrar { ...@@ -436,8 +418,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -446,7 +428,6 @@ struct KernelRegistrar { ...@@ -446,7 +428,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -454,7 +435,6 @@ struct KernelRegistrar { ...@@ -454,7 +435,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -463,8 +443,8 @@ struct KernelRegistrar { ...@@ -463,8 +443,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -473,7 +453,6 @@ struct KernelRegistrar { ...@@ -473,7 +453,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -481,7 +460,6 @@ struct KernelRegistrar { ...@@ -481,7 +460,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -490,8 +468,8 @@ struct KernelRegistrar { ...@@ -490,8 +468,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -500,7 +478,6 @@ struct KernelRegistrar { ...@@ -500,7 +478,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -508,7 +485,6 @@ struct KernelRegistrar { ...@@ -508,7 +485,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -517,8 +493,8 @@ struct KernelRegistrar { ...@@ -517,8 +493,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -527,7 +503,6 @@ struct KernelRegistrar { ...@@ -527,7 +503,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -535,7 +510,6 @@ struct KernelRegistrar { ...@@ -535,7 +510,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -544,8 +518,8 @@ struct KernelRegistrar { ...@@ -544,8 +518,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -554,7 +528,6 @@ struct KernelRegistrar { ...@@ -554,7 +528,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -562,7 +535,6 @@ struct KernelRegistrar { ...@@ -562,7 +535,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -571,8 +543,8 @@ struct KernelRegistrar { ...@@ -571,8 +543,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -581,7 +553,6 @@ struct KernelRegistrar { ...@@ -581,7 +553,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -589,7 +560,6 @@ struct KernelRegistrar { ...@@ -589,7 +560,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -598,8 +568,8 @@ struct KernelRegistrar { ...@@ -598,8 +568,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -608,7 +578,6 @@ struct KernelRegistrar { ...@@ -608,7 +578,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -616,7 +585,6 @@ struct KernelRegistrar { ...@@ -616,7 +585,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -625,8 +593,8 @@ struct KernelRegistrar { ...@@ -625,8 +593,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -635,7 +603,6 @@ struct KernelRegistrar { ...@@ -635,7 +603,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -643,7 +610,6 @@ struct KernelRegistrar { ...@@ -643,7 +610,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -652,8 +618,8 @@ struct KernelRegistrar { ...@@ -652,8 +618,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -662,7 +628,6 @@ struct KernelRegistrar { ...@@ -662,7 +628,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -670,7 +635,6 @@ struct KernelRegistrar { ...@@ -670,7 +635,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -679,8 +643,8 @@ struct KernelRegistrar { ...@@ -679,8 +643,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -689,7 +653,6 @@ struct KernelRegistrar { ...@@ -689,7 +653,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -697,7 +660,6 @@ struct KernelRegistrar { ...@@ -697,7 +660,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -706,8 +668,8 @@ struct KernelRegistrar { ...@@ -706,8 +668,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -716,7 +678,6 @@ struct KernelRegistrar { ...@@ -716,7 +678,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -724,7 +685,6 @@ struct KernelRegistrar { ...@@ -724,7 +685,6 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \
func_id, \
registrar_id, \ registrar_id, \
backend, \ backend, \
layout, \ layout, \
...@@ -733,8 +693,8 @@ struct KernelRegistrar { ...@@ -733,8 +693,8 @@ struct KernelRegistrar {
cpp_dtype, \ cpp_dtype, \
...) \ ...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \ static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \ __reg_pt_kernel_##kernel_name##_, registrar_id)( \
kernel_name, \ #kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
...@@ -743,7 +703,6 @@ struct KernelRegistrar { ...@@ -743,7 +703,6 @@ struct KernelRegistrar {
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
PT_ID, \ PT_ID, \
backend, \ backend, \
layout, \ layout, \
...@@ -751,90 +710,59 @@ struct KernelRegistrar { ...@@ -751,90 +710,59 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define PT_REGISTER_KERNEL_STANDARD( \ /** PT_REGISTER_SINGLE_KERNEL
*
* Used to register a single kernel, pass in the complete function pointer
* of the kernel, this registration macro will not do automatic template
* instantiation.
*/
#define PT_REGISTER_SINGLE_KERNEL( \
kernel_name, backend, layout, dtype, kernel_fn) \ kernel_name, backend, layout, dtype, kernel_fn) \
_PT_REGISTER_KERNEL_STANDARD( \
kernel_name, PT_ID, backend, layout, dtype, kernel_fn)
#define _PT_REGISTER_KERNEL_STANDARD( \
kernel_name, func_id, backend, layout, dtype, kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ pt_register_single_kernel_ns_check_##kernel_name, \
"_PT_REGISTER_KERNEL_STANDARD must be called in global namespace."); \ "PT_REGISTER_SINGLE_KERNEL must be called in global namespace."); \
template decltype(kernel_fn) kernel_fn; \ static void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ static const ::pten::KernelRegistrar __reg_pt_single_kernel_##kernel_name( \
func_id)(::pten::Kernel*); \ #kernel_name, \
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \
func_id)( \
kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
DATATYPE(dtype), \ DATATYPE(dtype), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(kernel_fn)); \ PT_KERNEL(kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pten::Kernel*) int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*)
// use to declare symbol
#define PT_REGISTER_MODULE(name) \
int RegisterSymbolsFor##name() { return 0; }
#define PT_DECLARE_MODULE(name) \ /** PT_REGISTER_KERNEL_ALL_DTYPE
extern int RegisterSymbolsFor##name(); \ *
UNUSED static int use_kernel_module_##name = RegisterSymbolsFor##name() * Used to register a kernel that supports all data types, such as copy and
* reshape that are not sensitive to data types.
// only used in cpp tests */
#define PT_REGISTER_KERNEL_ALL_DTYPE(kernel_name, backend, layout, kernel_fn) \
#define PT_REGISTER_KERNEL_FOR_TEST( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
_PT_REGISTER_KERNEL_FOR_TEST(kernel_name, \
PT_ID, \
backend, \
layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_REGISTER_KERNEL_FOR_TEST( \
kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_for_test_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
static void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \
func_id)(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
func_id, \
backend, \
layout, \
&PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, func_id), \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
#define PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, backend, layout, meta_kernel_fn) \
_PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, PT_ID, backend, layout, meta_kernel_fn)
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, func_id, backend, layout, meta_kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ pt_register_kernel_all_dtype_ns_check_##kernel_name, \
"PT_REGISTER_KERNEL must be called in global namespace."); \ "PT_REGISTER_KERNEL_ALL_DTYPE must be called in global namespace."); \
decltype(meta_kernel_fn) meta_kernel_fn; \ static void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name( \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ ::pten::Kernel*); \
func_id)(::pten::Kernel*); \ static const ::pten::KernelRegistrar \
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ __reg_pt_kernel_all_dtype_##kernel_name( \
func_id)( \ #kernel_name, \
kernel_name, \
BACKEND(backend), \ BACKEND(backend), \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \ ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ &__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \
PT_KERNEL(meta_kernel_fn)); \ PT_KERNEL(kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
func_id)(::pten::Kernel * kernel) void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel)
/** PT_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
* to avoid being removed by linker
*/
#define PT_DECLARE_KERNEL(kernel_name, backend) \
extern int TouchKernelSymbolFor_##kernel_name##_##backend(); \
UNUSED static int __declare_kernel_symbol_for_##kernel_name##_##backend = \
TouchKernelSymbolFor_##kernel_name##_##backend()
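
The declare/touch pairing above is the usual trick for keeping statically registered objects from being dropped when the kernel's object file sits in a static library: the registering translation unit exports a trivial function, and any translation unit that needs the kernel references it. A self-contained miniature of the pattern, with hypothetical names rather than the real pten types:

```cpp
#include <iostream>

// --- in the "kernel" translation unit (normally a separate .cc file) ---
struct Registrar {
  explicit Registrar(const char* name) { std::cout << "registered " << name << "\n"; }
};
static const Registrar reg_demo_kernel("demo");    // runs at static-initialization time
int TouchKernelSymbolFor_demo_CPU() { return 0; }  // exported "touch" symbol

// --- in a consuming translation unit ---
extern int TouchKernelSymbolFor_demo_CPU();
static int use_demo_cpu = TouchKernelSymbolFor_demo_CPU();  // forces the linker to keep the object file

int main() { return use_demo_cpu; }
```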
} // namespace pten } // namespace pten
...@@ -61,9 +61,7 @@ void FillConstant(const CPUContext& dev_ctx, ...@@ -61,9 +61,7 @@ void FillConstant(const CPUContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_MODULE(CreationCPU); PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL("full_like",
CPU, CPU,
ANY, ANY,
pten::FillAnyLike, pten::FillAnyLike,
...@@ -74,7 +72,7 @@ PT_REGISTER_KERNEL("full_like", ...@@ -74,7 +72,7 @@ PT_REGISTER_KERNEL("full_like",
bool, bool,
paddle::platform::float16) {} paddle::platform::float16) {}
PT_REGISTER_KERNEL("full", PT_REGISTER_KERNEL(full,
CPU, CPU,
ANY, ANY,
pten::FillConstant, pten::FillConstant,
......
...@@ -70,12 +70,10 @@ void Matmul(const CPUContext& dev_ctx, ...@@ -70,12 +70,10 @@ void Matmul(const CPUContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_MODULE(LinalgCPU);
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("dot", PT_REGISTER_KERNEL(dot,
CPU, CPU,
ANY, ANY,
pten::Dot, pten::Dot,
...@@ -87,5 +85,4 @@ PT_REGISTER_KERNEL("dot", ...@@ -87,5 +85,4 @@ PT_REGISTER_KERNEL("dot",
complex128) {} complex128) {}
PT_REGISTER_KERNEL( PT_REGISTER_KERNEL(
"matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) { matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
}
...@@ -130,12 +130,9 @@ void Cast(const CPUContext& dev_ctx, ...@@ -130,12 +130,9 @@ void Cast(const CPUContext& dev_ctx,
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCPU);
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel // TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten". // architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten", PT_REGISTER_KERNEL(flatten,
CPU, CPU,
ANY, ANY,
pten::Flatten, pten::Flatten,
...@@ -145,8 +142,7 @@ PT_REGISTER_KERNEL("flatten", ...@@ -145,8 +142,7 @@ PT_REGISTER_KERNEL("flatten",
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(flatten_mid,
PT_REGISTER_KERNEL("flatten.mid",
CPU, CPU,
ANY, ANY,
pten::FlattenWithXShape, pten::FlattenWithXShape,
...@@ -156,7 +152,8 @@ PT_REGISTER_KERNEL("flatten.mid", ...@@ -156,7 +152,8 @@ PT_REGISTER_KERNEL("flatten.mid",
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL("cast",
PT_REGISTER_KERNEL(cast,
CPU, CPU,
ANY, ANY,
pten::Cast, pten::Cast,
...@@ -174,39 +171,30 @@ PT_REGISTER_KERNEL("cast", ...@@ -174,39 +171,30 @@ PT_REGISTER_KERNEL("cast",
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
} }
// TODO(yuanrisheng): "reshape2" is compatible with old kernel PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {}
// architecture, kernel_name should be "reshape". PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape",
CPU,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid",
CPU, CPU,
ANY, ANY,
pten::ReshapeFromVectorValWithXShape) {} pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) {
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host", CPU, ANY, pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid",
CPU, CPU,
ANY, ANY,
pten::ReshapeFromDTWithXShape) { pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
} }
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost", PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CPU, CPU,
ANY, ANY,
pten::ReshapeFromVectorDT) { pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid",
CPU, CPU,
ANY, ANY,
pten::ReshapeFromVectorDTWithXShape) { pten::ReshapeFromVectorDTWithXShape) {
......
...@@ -106,18 +106,14 @@ DEFINE_CPU_ELEMENTWISE_OP(Mul) ...@@ -106,18 +106,14 @@ DEFINE_CPU_ELEMENTWISE_OP(Mul)
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCPU);
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16; // using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {} PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("mean", CPU, ANY, pten::Mean, float, double, bool) {} PT_REGISTER_KERNEL(scale,
PT_REGISTER_KERNEL("scale",
CPU, CPU,
ANY, ANY,
pten::Scale, pten::Scale,
...@@ -129,8 +125,7 @@ PT_REGISTER_KERNEL("scale", ...@@ -129,8 +125,7 @@ PT_REGISTER_KERNEL("scale",
int16_t, int16_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(add,
PT_REGISTER_KERNEL("add",
CPU, CPU,
ANY, ANY,
pten::ElementwiseAdd, pten::ElementwiseAdd,
...@@ -140,7 +135,7 @@ PT_REGISTER_KERNEL("add", ...@@ -140,7 +135,7 @@ PT_REGISTER_KERNEL("add",
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("subtract", PT_REGISTER_KERNEL(subtract,
CPU, CPU,
ANY, ANY,
pten::ElementwiseSub, pten::ElementwiseSub,
...@@ -150,7 +145,7 @@ PT_REGISTER_KERNEL("subtract", ...@@ -150,7 +145,7 @@ PT_REGISTER_KERNEL("subtract",
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("divide", PT_REGISTER_KERNEL(divide,
CPU, CPU,
ANY, ANY,
pten::ElementwiseDiv, pten::ElementwiseDiv,
...@@ -160,7 +155,7 @@ PT_REGISTER_KERNEL("divide", ...@@ -160,7 +155,7 @@ PT_REGISTER_KERNEL("divide",
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("multiply", PT_REGISTER_KERNEL(multiply,
CPU, CPU,
ANY, ANY,
pten::ElementwiseMul, pten::ElementwiseMul,
...@@ -171,8 +166,7 @@ PT_REGISTER_KERNEL("multiply", ...@@ -171,8 +166,7 @@ PT_REGISTER_KERNEL("multiply",
bool, bool,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(sum,
PT_REGISTER_KERNEL("sum",
CPU, CPU,
ANY, ANY,
pten::Sum, pten::Sum,
......
...@@ -57,7 +57,4 @@ void Copy(const CPUContext& dev_ctx, ...@@ -57,7 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
PT_REGISTER_MODULE(UtilsCPU);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CPU, ANY, pten::Copy) {}
...@@ -62,9 +62,7 @@ void FillConstant(const CUDAContext& dev_ctx, ...@@ -62,9 +62,7 @@ void FillConstant(const CUDAContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_MODULE(CreationCUDA); PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL("full_like",
CUDA, CUDA,
ANY, ANY,
pten::FillAnyLike, pten::FillAnyLike,
...@@ -75,7 +73,7 @@ PT_REGISTER_KERNEL("full_like", ...@@ -75,7 +73,7 @@ PT_REGISTER_KERNEL("full_like",
bool, bool,
paddle::platform::float16) {} paddle::platform::float16) {}
PT_REGISTER_KERNEL("full", PT_REGISTER_KERNEL(full,
CUDA, CUDA,
ANY, ANY,
pten::FillConstant, pten::FillConstant,
......
...@@ -54,13 +54,11 @@ void Matmul(const CUDAContext& dev_ctx, ...@@ -54,13 +54,11 @@ void Matmul(const CUDAContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_MODULE(LinalgCUDA);
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("dot", PT_REGISTER_KERNEL(dot,
CUDA, CUDA,
ANY, ANY,
pten::Dot, pten::Dot,
...@@ -71,7 +69,7 @@ PT_REGISTER_KERNEL("dot", ...@@ -71,7 +69,7 @@ PT_REGISTER_KERNEL("dot",
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("matmul_v2", PT_REGISTER_KERNEL(matmul_v2,
CUDA, CUDA,
ANY, ANY,
pten::Matmul, pten::Matmul,
......
...@@ -129,13 +129,9 @@ void Cast(const CUDAContext& dev_ctx, ...@@ -129,13 +129,9 @@ void Cast(const CUDAContext& dev_ctx,
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCUDA);
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten". PT_REGISTER_KERNEL(flatten,
PT_REGISTER_KERNEL("flatten",
CUDA, CUDA,
ANY, ANY,
pten::Flatten, pten::Flatten,
...@@ -146,8 +142,7 @@ PT_REGISTER_KERNEL("flatten", ...@@ -146,8 +142,7 @@ PT_REGISTER_KERNEL("flatten",
int8_t, int8_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(flatten_mid,
PT_REGISTER_KERNEL("flatten.mid",
CUDA, CUDA,
ANY, ANY,
pten::FlattenWithXShape, pten::FlattenWithXShape,
...@@ -159,7 +154,7 @@ PT_REGISTER_KERNEL("flatten.mid", ...@@ -159,7 +154,7 @@ PT_REGISTER_KERNEL("flatten.mid",
int64_t) {} int64_t) {}
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL("cast", \ PT_REGISTER_KERNEL(cast, \
CUDA, \ CUDA, \
ANY, \ ANY, \
pten::Cast, \ pten::Cast, \
...@@ -184,41 +179,30 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) ...@@ -184,41 +179,30 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif #endif
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape", PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {}
CUDA, PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid",
CUDA, CUDA,
ANY, ANY,
pten::ReshapeFromVectorValWithXShape) {} pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) {
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host",
CUDA,
ANY,
pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid",
CUDA, CUDA,
ANY, ANY,
pten::ReshapeFromDTWithXShape) { pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost",
CUDA, CUDA,
ANY, ANY,
pten::ReshapeFromVectorDT) { pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU); kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid",
CUDA, CUDA,
ANY, ANY,
pten::ReshapeFromVectorDTWithXShape) { pten::ReshapeFromVectorDTWithXShape) {
......
...@@ -111,16 +111,13 @@ void Sum(const CUDAContext& dev_ctx, ...@@ -111,16 +111,13 @@ void Sum(const CUDAContext& dev_ctx,
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCUDA);
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("sign", CUDA, ANY, pten::Sign, float, double, float16) {} PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL("mean", CUDA, ANY, pten::Mean, float, double, bool) {} PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale", PT_REGISTER_KERNEL(scale,
CUDA, CUDA,
ANY, ANY,
pten::Scale, pten::Scale,
...@@ -132,7 +129,7 @@ PT_REGISTER_KERNEL("scale", ...@@ -132,7 +129,7 @@ PT_REGISTER_KERNEL("scale",
int16_t, int16_t,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL("add", PT_REGISTER_KERNEL(add,
CUDA, CUDA,
ANY, ANY,
pten::ElementwiseAdd, pten::ElementwiseAdd,
...@@ -143,7 +140,7 @@ PT_REGISTER_KERNEL("add", ...@@ -143,7 +140,7 @@ PT_REGISTER_KERNEL("add",
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("subtract", PT_REGISTER_KERNEL(subtract,
CUDA, CUDA,
ANY, ANY,
pten::ElementwiseSub, pten::ElementwiseSub,
...@@ -154,7 +151,7 @@ PT_REGISTER_KERNEL("subtract", ...@@ -154,7 +151,7 @@ PT_REGISTER_KERNEL("subtract",
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("divide", PT_REGISTER_KERNEL(divide,
CUDA, CUDA,
ANY, ANY,
pten::ElementwiseDiv, pten::ElementwiseDiv,
...@@ -165,7 +162,7 @@ PT_REGISTER_KERNEL("divide", ...@@ -165,7 +162,7 @@ PT_REGISTER_KERNEL("divide",
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("multiply", PT_REGISTER_KERNEL(multiply,
CUDA, CUDA,
ANY, ANY,
pten::ElementwiseMul, pten::ElementwiseMul,
...@@ -177,7 +174,7 @@ PT_REGISTER_KERNEL("multiply", ...@@ -177,7 +174,7 @@ PT_REGISTER_KERNEL("multiply",
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL("sum", PT_REGISTER_KERNEL(sum,
CUDA, CUDA,
ANY, ANY,
pten::Sum, pten::Sum,
......
...@@ -234,7 +234,4 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -234,7 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
} }
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {}
PT_REGISTER_MODULE(UtilsCUDA);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CUDA, ANY, pten::Copy) {}
...@@ -95,12 +95,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx, ...@@ -95,12 +95,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx,
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl PT_REGISTER_KERNEL(flatten,
PT_REGISTER_MODULE(ManipulationXPU);
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten_contiguous_range",
XPU, XPU,
ANY, ANY,
pten::Flatten, pten::Flatten,
...@@ -112,7 +107,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range", ...@@ -112,7 +107,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range",
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL("flatten_contiguous_range.mid", PT_REGISTER_KERNEL(flatten_mid,
XPU, XPU,
ANY, ANY,
pten::FlattenWithXShape, pten::FlattenWithXShape,
...@@ -124,9 +119,4 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", ...@@ -124,9 +119,4 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int, int,
int64_t) {} int64_t) {}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {}
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
XPU,
ANY,
pten::ReshapeFromVectorVal) {}
...@@ -76,7 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx, ...@@ -76,7 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten } // namespace pten
// TODO(chenweihang): replace by better impl PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {}
PT_REGISTER_MODULE(UtilsXPU);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", XPU, ANY, pten::Copy) {}
...@@ -21,12 +21,6 @@ limitations under the License. */ ...@@ -21,12 +21,6 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace paddle { namespace paddle {
namespace tests { namespace tests {
......
...@@ -345,6 +345,7 @@ def source_include(header_file_path): ...@@ -345,6 +345,7 @@ def source_include(header_file_path):
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/pten/api/lib/api_registry.h" #include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/kernel_declare.h"
#include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
...@@ -353,22 +354,6 @@ def source_include(header_file_path): ...@@ -353,22 +354,6 @@ def source_include(header_file_path):
""" """
def module_declare():
return """
PT_DECLARE_MODULE(CreationCPU);
PT_DECLARE_MODULE(LinalgCPU);
PT_DECLARE_MODULE(ManipulationCPU);
PT_DECLARE_MODULE(MathCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(CreationCUDA);
PT_DECLARE_MODULE(LinalgCUDA);
PT_DECLARE_MODULE(ManipulationCUDA);
PT_DECLARE_MODULE(MathCUDA);
#endif
"""
def api_register(): def api_register():
return """ return """
PT_REGISTER_API(Creation); PT_REGISTER_API(Creation);
...@@ -405,7 +390,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): ...@@ -405,7 +390,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
include_header_file = "paddle/pten/api/include/api.h" include_header_file = "paddle/pten/api/include/api.h"
source_file.write(source_include(include_header_file)) source_file.write(source_include(include_header_file))
source_file.write(module_declare())
source_file.write(namespace[0]) source_file.write(namespace[0])
for api in apis: for api in apis:
......