未验证 提交 b199ba85 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Refine Kernel Registrar Writing (#37977)

* refine the kernel register impl

* fix cmake and symbol error

* remove overload macro

* polish details
上级 dfed4a63
......@@ -555,10 +555,10 @@ class Reshape2Op : public ReshapeOp {
const framework::ExecutionContext &ctx) const override {
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
return framework::KernelSignature("reshape.mulhost", {"X", "ShapeTensor"},
return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"},
{}, {"Out"});
} else if (ctx.HasInput("Shape")) {
return framework::KernelSignature("reshape.host", {"X", "Shape"}, {},
return framework::KernelSignature("reshape_host", {"X", "Shape"}, {},
{"Out"});
} else {
return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/kernel_registry.h"
// TODO(chenweihang): After each kernel is split into its own file, the
// kernel declare statements will be generated automatically from the kernel
// file names, and this header file will be removed.
//
// Each PT_DECLARE_KERNEL(kernel_name, backend) statement below refers to the
// TouchKernelSymbolFor_<kernel_name>_<backend>() function that the kernel
// registration macros define, so including this header references the
// translation units in which those kernels are registered.

// Kernels declared for the CPU backend.
PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
// Kernels declared for the CUDA backend; only when building with CUDA or HIP.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
#endif
// Kernels declared for the XPU backend; only when building with XPU.
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
#endif
......@@ -25,10 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
PT_DECLARE_MODULE(UtilsCPU);
PT_DECLARE_KERNEL(copy, CPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(UtilsCUDA);
PT_DECLARE_KERNEL(copy, CUDA);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU);
#endif
namespace paddle {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cstring>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
......@@ -24,6 +25,8 @@
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace pten {
#define BACKEND(arg__) pten::Backend::arg__
......@@ -193,64 +196,35 @@ struct KernelRegistrar {
#define _PT_ARG_N(args) _PT_ARG_N_EXPAND args
#define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
/** PT_REGISTER_KERNEL
*
* The most frequently used kernel registration macro, used for kernel
* registration with only data type as template parameter, and the function
* pointer of the corresponding data type is automatically instantiated
* during registration.
*/
#define PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
_PT_REGISTER_KERNEL(kernel_name, \
PT_ID, \
backend, \
layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_kernel_ns_check_##kernel_name, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
_PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_KERNEL( \
kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \
func_id, \
backend, \
layout, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
// Implementation of PT_REGISTER_KERNEL that also instantiates the kernel
// template. The expansion, in order:
//   1. explicitly instantiates meta_kernel_fn for every listed data type
//      (PT_KERNEL_INSTANTIATION);
//   2. forward-declares the static args-def function for this kernel name;
//   3. emits the registrar objects through PT_KERNEL_REGISTRAR_INIT, passing
//      the address of that args-def function;
//   4. ends with the header of the args-def function definition, so the
//      `{ ... }` block written after the macro call becomes its body and can
//      use the `kernel` parameter.
#define _PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
#else
#define _PT_REGISTER_KERNEL( \
kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \
func_id, \
backend, \
layout, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, cpp_dtype, __VA_ARGS__)
/**
* `template decltype(fn) fn` works on gcc and clang,
* but fails on msvc with an error like:
......@@ -261,8 +235,30 @@ struct KernelRegistrar {
*
* https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua
*
* So we solve the explicit instantiation of kernels via CMake,
* and msvc can work without template instantiation
*/
// Implementation of PT_REGISTER_KERNEL without explicit template
// instantiation: it forward-declares the static args-def function, emits the
// registrar objects through PT_KERNEL_REGISTRAR_INIT, and ends with the
// header of the args-def function definition so that the `{ ... }` block
// written after the macro call becomes its body (with `kernel` as the
// parameter name).
// NOTE(review): presumably this is the variant selected for msvc builds,
// where PT_KERNEL_INSTANTIATION cannot be used -- confirm against the
// surrounding #ifndef _WIN32 / #else structure.
#define _PT_REGISTER_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT(kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
#endif
// Entry point for explicit kernel instantiation: passes the result of
// PT_NARGS(cpp_dtype, ...) together with the kernel function and data types
// to _PT_KERNEL_INSTANTIATION.
// NOTE(review): PT_NARGS is defined outside this view; presumably it yields
// the number of data types passed -- confirm.
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// Dispatches to _PT_KERNEL_INSTANTIATION_<N> by pasting N onto the macro
// name, then invokes that macro with the kernel function and data types.
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, cpp_dtype, __VA_ARGS__)
// Single-type case: expands to one explicit instantiation of
// meta_kernel_fn<cpp_dtype>. The variadic arguments are not used, and the
// expansion has no trailing semicolon; the invoking macro supplies it.
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>
......@@ -309,22 +305,15 @@ struct KernelRegistrar {
template decltype(meta_kernel_fn<cpp_dtype>) meta_kernel_fn<cpp_dtype>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT(kernel_name, \
func_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
func_id, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
#define PT_KERNEL_REGISTRAR_INIT( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format off
......@@ -333,7 +322,6 @@ struct KernelRegistrar {
and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N, \
kernel_name, \
func_id, \
backend, \
layout, \
args_def_fn, \
......@@ -342,7 +330,6 @@ struct KernelRegistrar {
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -354,7 +341,6 @@ struct KernelRegistrar {
// clang-format on
#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -363,17 +349,17 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>));
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -382,8 +368,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -392,7 +378,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -400,7 +385,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -409,8 +393,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -419,7 +403,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -427,7 +410,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -436,8 +418,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -446,7 +428,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -454,7 +435,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -463,8 +443,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -473,7 +453,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -481,7 +460,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -490,8 +468,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -500,7 +478,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -508,7 +485,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -517,8 +493,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -527,7 +503,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -535,7 +510,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -544,8 +518,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -554,7 +528,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -562,7 +535,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -571,8 +543,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -581,7 +553,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -589,7 +560,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -598,8 +568,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -608,7 +578,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -616,7 +585,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -625,8 +593,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -635,7 +603,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -643,7 +610,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -652,8 +618,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -662,7 +628,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -670,7 +635,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -679,8 +643,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -689,7 +653,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -697,7 +660,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -706,8 +668,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -716,7 +678,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -724,7 +685,6 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \
func_id, \
registrar_id, \
backend, \
layout, \
......@@ -733,8 +693,8 @@ struct KernelRegistrar {
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_op_kernel_##func_id##_, registrar_id)( \
kernel_name, \
__reg_pt_kernel_##kernel_name##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
......@@ -743,7 +703,6 @@ struct KernelRegistrar {
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
func_id, \
PT_ID, \
backend, \
layout, \
......@@ -751,90 +710,59 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__))
#define PT_REGISTER_KERNEL_STANDARD( \
kernel_name, backend, layout, dtype, kernel_fn) \
_PT_REGISTER_KERNEL_STANDARD( \
kernel_name, PT_ID, backend, layout, dtype, kernel_fn)
#define _PT_REGISTER_KERNEL_STANDARD( \
kernel_name, func_id, backend, layout, dtype, kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"_PT_REGISTER_KERNEL_STANDARD must be called in global namespace."); \
template decltype(kernel_fn) kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \
func_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
DATATYPE(dtype), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
args_def_fn, \
PT_KERNEL(kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pten::Kernel*)
// use to declare symbol
#define PT_REGISTER_MODULE(name) \
int RegisterSymbolsFor##name() { return 0; }
#define PT_DECLARE_MODULE(name) \
extern int RegisterSymbolsFor##name(); \
UNUSED static int use_kernel_module_##name = RegisterSymbolsFor##name()
// only used in cpp tests
#define PT_REGISTER_KERNEL_FOR_TEST( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
_PT_REGISTER_KERNEL_FOR_TEST(kernel_name, \
PT_ID, \
backend, \
layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_REGISTER_KERNEL_FOR_TEST( \
kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_for_test_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
static void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \
func_id)(::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
func_id, \
backend, \
layout, \
&PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, func_id), \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
#define PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, backend, layout, meta_kernel_fn) \
_PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, PT_ID, backend, layout, meta_kernel_fn)
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, func_id, backend, layout, meta_kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
decltype(meta_kernel_fn) meta_kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \
func_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
PT_KERNEL(meta_kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
/** PT_REGISTER_SINGLE_KERNEL
 *
 * Registers exactly one kernel function for one (backend, layout, dtype)
 * combination. `kernel_fn` must be the complete kernel function; this
 * registration macro does not do automatic template instantiation.
 *
 * The macro must be called in the global namespace and must be followed by
 * a `{ ... }` block. That block becomes the body of the generated args-def
 * function, whose ::pten::Kernel* parameter is named `kernel`.
 *
 * The address of that args-def function is what gets handed to the
 * ::pten::KernelRegistrar, so it has to be forward-declared before the
 * registrar object is defined.
 *
 * The macro also defines TouchKernelSymbolFor_<kernel_name>_<backend>(),
 * which is the symbol PT_DECLARE_KERNEL(kernel_name, backend) refers to.
 */
#define PT_REGISTER_SINGLE_KERNEL( \
    kernel_name, backend, layout, dtype, kernel_fn) \
  PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
      pt_register_single_kernel_ns_check_##kernel_name, \
      "PT_REGISTER_SINGLE_KERNEL must be called in global namespace."); \
  static void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*); \
  static const ::pten::KernelRegistrar __reg_pt_single_kernel_##kernel_name( \
      #kernel_name, \
      BACKEND(backend), \
      DATALAYOUT(layout), \
      DATATYPE(dtype), \
      ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
      &__PT_SINGLE_KERNEL_args_def_FN_##kernel_name, \
      PT_KERNEL(kernel_fn)); \
  int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
  void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel* kernel)
/** PT_REGISTER_KERNEL_ALL_DTYPE
*
* Used to register a kernel that supports all data types, such as copy and
* reshape that are not sensitive to data types.
*
* No data type argument is passed to the ::pten::KernelRegistrar here; only
* the kernel name string, backend, layout, argument-parse function, args-def
* function and kernel function are supplied.
*
* The macro must be called in the global namespace and followed by a
* `{ ... }` block, which becomes the body of the generated args-def
* function; inside it the ::pten::Kernel* parameter is named `kernel`.
*
* It also defines TouchKernelSymbolFor_<kernel_name>_<backend>(), the symbol
* that PT_DECLARE_KERNEL(kernel_name, backend) refers to.
*/
#define PT_REGISTER_KERNEL_ALL_DTYPE(kernel_name, backend, layout, kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_kernel_all_dtype_ns_check_##kernel_name, \
"PT_REGISTER_KERNEL_ALL_DTYPE must be called in global namespace."); \
static void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name( \
::pten::Kernel*); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_all_dtype_##kernel_name( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \
PT_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel)
/** PT_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
* to avoid being removed by linker
*
* Expands to an extern declaration of
* TouchKernelSymbolFor_<kernel_name>_<backend>() and a static int that is
* initialized by calling it. The function itself is defined by the kernel
* registration macros, so (kernel_name, backend) must match a registration.
* The statement is written without a trailing semicolon; the caller adds it.
* NOTE(review): UNUSED is defined outside this view; presumably it marks the
* static variable as intentionally unused -- confirm.
*/
#define PT_DECLARE_KERNEL(kernel_name, backend) \
extern int TouchKernelSymbolFor_##kernel_name##_##backend(); \
UNUSED static int __declare_kernel_symbol_for_##kernel_name##_##backend = \
TouchKernelSymbolFor_##kernel_name##_##backend()
} // namespace pten
......@@ -61,9 +61,7 @@ void FillConstant(const CPUContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(CreationCPU);
PT_REGISTER_KERNEL("full_like",
PT_REGISTER_KERNEL(full_like,
CPU,
ANY,
pten::FillAnyLike,
......@@ -74,7 +72,7 @@ PT_REGISTER_KERNEL("full_like",
bool,
paddle::platform::float16) {}
PT_REGISTER_KERNEL("full",
PT_REGISTER_KERNEL(full,
CPU,
ANY,
pten::FillConstant,
......
......@@ -70,12 +70,10 @@ void Matmul(const CPUContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(LinalgCPU);
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("dot",
PT_REGISTER_KERNEL(dot,
CPU,
ANY,
pten::Dot,
......@@ -87,5 +85,4 @@ PT_REGISTER_KERNEL("dot",
complex128) {}
PT_REGISTER_KERNEL(
"matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) {
}
matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
......@@ -130,12 +130,9 @@ void Cast(const CPUContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCPU);
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten",
PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
pten::Flatten,
......@@ -145,8 +142,7 @@ PT_REGISTER_KERNEL("flatten",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("flatten.mid",
PT_REGISTER_KERNEL(flatten_mid,
CPU,
ANY,
pten::FlattenWithXShape,
......@@ -156,7 +152,8 @@ PT_REGISTER_KERNEL("flatten.mid",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("cast",
PT_REGISTER_KERNEL(cast,
CPU,
ANY,
pten::Cast,
......@@ -174,42 +171,33 @@ PT_REGISTER_KERNEL("cast",
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape",
CPU,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid",
CPU,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host", CPU, ANY, pten::ReshapeFromDT) {
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
CPU,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid",
CPU,
ANY,
pten::ReshapeFromDTWithXShape) {
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
CPU,
ANY,
pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost",
CPU,
ANY,
pten::ReshapeFromVectorDT) {
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CPU,
ANY,
pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid",
CPU,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
CPU,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
......@@ -106,18 +106,14 @@ DEFINE_CPU_ELEMENTWISE_OP(Mul)
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCPU);
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL("mean", CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale",
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ANY,
pten::Scale,
......@@ -129,8 +125,7 @@ PT_REGISTER_KERNEL("scale",
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("add",
PT_REGISTER_KERNEL(add,
CPU,
ANY,
pten::ElementwiseAdd,
......@@ -140,7 +135,7 @@ PT_REGISTER_KERNEL("add",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("subtract",
PT_REGISTER_KERNEL(subtract,
CPU,
ANY,
pten::ElementwiseSub,
......@@ -150,7 +145,7 @@ PT_REGISTER_KERNEL("subtract",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("divide",
PT_REGISTER_KERNEL(divide,
CPU,
ANY,
pten::ElementwiseDiv,
......@@ -160,7 +155,7 @@ PT_REGISTER_KERNEL("divide",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("multiply",
PT_REGISTER_KERNEL(multiply,
CPU,
ANY,
pten::ElementwiseMul,
......@@ -171,8 +166,7 @@ PT_REGISTER_KERNEL("multiply",
bool,
complex64,
complex128) {}
PT_REGISTER_KERNEL("sum",
PT_REGISTER_KERNEL(sum,
CPU,
ANY,
pten::Sum,
......
......@@ -57,7 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(UtilsCPU);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CPU, ANY, pten::Copy) {}
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
......@@ -62,9 +62,7 @@ void FillConstant(const CUDAContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(CreationCUDA);
PT_REGISTER_KERNEL("full_like",
PT_REGISTER_KERNEL(full_like,
CUDA,
ANY,
pten::FillAnyLike,
......@@ -75,7 +73,7 @@ PT_REGISTER_KERNEL("full_like",
bool,
paddle::platform::float16) {}
PT_REGISTER_KERNEL("full",
PT_REGISTER_KERNEL(full,
CUDA,
ANY,
pten::FillConstant,
......
......@@ -54,13 +54,11 @@ void Matmul(const CUDAContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(LinalgCUDA);
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("dot",
PT_REGISTER_KERNEL(dot,
CUDA,
ANY,
pten::Dot,
......@@ -71,7 +69,7 @@ PT_REGISTER_KERNEL("dot",
complex64,
complex128) {}
PT_REGISTER_KERNEL("matmul_v2",
PT_REGISTER_KERNEL(matmul_v2,
CUDA,
ANY,
pten::Matmul,
......
......@@ -129,13 +129,9 @@ void Cast(const CUDAContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCUDA);
using float16 = paddle::platform::float16;
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten",
PT_REGISTER_KERNEL(flatten,
CUDA,
ANY,
pten::Flatten,
......@@ -146,8 +142,7 @@ PT_REGISTER_KERNEL("flatten",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("flatten.mid",
PT_REGISTER_KERNEL(flatten_mid,
CUDA,
ANY,
pten::FlattenWithXShape,
......@@ -159,7 +154,7 @@ PT_REGISTER_KERNEL("flatten.mid",
int64_t) {}
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL("cast", \
PT_REGISTER_KERNEL(cast, \
CUDA, \
ANY, \
pten::Cast, \
......@@ -184,44 +179,33 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape",
CUDA,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid",
CUDA,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host",
CUDA,
ANY,
pten::ReshapeFromDT) {
// Registers the kernel named `reshape` for CUDA, implemented by
// pten::ReshapeFromVectorVal. The `{}` body is empty: nothing about the
// kernel's arguments is customized here.
// NOTE(review): `ANY` is presumably the layout argument; confirm against
// paddle/pten/core/kernel_registry.h.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {}
// Registers the kernel named `reshape_mid` for CUDA, implemented by
// pten::ReshapeFromVectorValWithXShape. The `{}` body is empty, so no
// per-argument settings are applied.
// NOTE(review): the `_mid` suffix presumably marks the variant that also
// produces XShape (going by the function name); confirm with the kernel
// definition.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
CUDA,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
// Registers the kernel named `reshape_host` for CUDA, implemented by
// pten::ReshapeFromDT, and then adjusts the definition of its input at
// index 1 in the body below.
// `kernel` is not declared in this body; presumably the registration macro
// supplies it (see paddle/pten/core/kernel_registry.h); confirm.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) {
// Input 1 is declared as living on the CPU backend even though the kernel
// itself is registered for CUDA.
// NOTE(review): Reshape2Op's kernel signature for "reshape_host" lists inputs
// {"X", "Shape"}, so index 1 presumably corresponds to `Shape`; confirm.
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
// Input 1 is also declared as INT32.
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid",
CUDA,
ANY,
pten::ReshapeFromDTWithXShape) {
// Registers the kernel named `reshape_host_mid` for CUDA, implemented by
// pten::ReshapeFromDTWithXShape, and then adjusts the definition of its
// input at index 1 in the body below.
// `kernel` is not declared in this body; presumably the registration macro
// supplies it (see paddle/pten/core/kernel_registry.h); confirm.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
CUDA,
ANY,
pten::ReshapeFromDTWithXShape) {
// Input 1 is declared as a CPU-backend argument of this CUDA kernel.
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
// Input 1 is declared as INT32.
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost",
CUDA,
ANY,
pten::ReshapeFromVectorDT) {
// Registers the kernel named `reshape_mulhost` for CUDA, implemented by
// pten::ReshapeFromVectorDT, and then adjusts the definition of its input at
// index 1 in the body below.
// `kernel` is not declared in this body; presumably the registration macro
// supplies it (see paddle/pten/core/kernel_registry.h); confirm.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CUDA,
ANY,
pten::ReshapeFromVectorDT) {
// Input 1 is declared as a CPU-backend argument of this CUDA kernel.
// NOTE(review): Reshape2Op's kernel signature for "reshape_mulhost" lists
// inputs {"X", "ShapeTensor"}, so index 1 presumably corresponds to
// `ShapeTensor`; confirm.
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
// Input 1 is declared as INT32.
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid",
CUDA,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
// Registers the kernel named `reshape_mulhost_mid` for CUDA, implemented by
// pten::ReshapeFromVectorDTWithXShape, and then adjusts the definition of its
// input at index 1 in the body below.
// `kernel` is not declared in this body; presumably the registration macro
// supplies it (see paddle/pten/core/kernel_registry.h); confirm.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
CUDA,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
// Input 1 is declared as a CPU-backend argument of this CUDA kernel.
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
// Input 1 is declared as INT32.
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
......@@ -111,16 +111,13 @@ void Sum(const CUDAContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCUDA);
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("sign", CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL("mean", CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale",
// Registers the kernel named `sign` for CUDA, implemented by pten::Sign, with
// float, double and float16 as the trailing type arguments. The `{}` body is
// empty, so no per-argument settings are applied.
// NOTE(review): the trailing types are presumably the data types the kernel
// is instantiated for; confirm against paddle/pten/core/kernel_registry.h.
PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {}
// Registers the kernel named `mean` for CUDA, implemented by pten::Mean, with
// float, double and bool as the trailing type arguments. The `{}` body is
// empty, so no per-argument settings are applied.
// NOTE(review): the trailing types are presumably the data types the kernel
// is instantiated for; confirm against paddle/pten/core/kernel_registry.h.
PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CUDA,
ANY,
pten::Scale,
......@@ -132,7 +129,7 @@ PT_REGISTER_KERNEL("scale",
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("add",
PT_REGISTER_KERNEL(add,
CUDA,
ANY,
pten::ElementwiseAdd,
......@@ -143,7 +140,7 @@ PT_REGISTER_KERNEL("add",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("subtract",
PT_REGISTER_KERNEL(subtract,
CUDA,
ANY,
pten::ElementwiseSub,
......@@ -154,7 +151,7 @@ PT_REGISTER_KERNEL("subtract",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("divide",
PT_REGISTER_KERNEL(divide,
CUDA,
ANY,
pten::ElementwiseDiv,
......@@ -165,7 +162,7 @@ PT_REGISTER_KERNEL("divide",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("multiply",
PT_REGISTER_KERNEL(multiply,
CUDA,
ANY,
pten::ElementwiseMul,
......@@ -177,7 +174,7 @@ PT_REGISTER_KERNEL("multiply",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("sum",
PT_REGISTER_KERNEL(sum,
CUDA,
ANY,
pten::Sum,
......
......@@ -234,7 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
}
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(UtilsCUDA);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CUDA, ANY, pten::Copy) {}
// Registers the kernel named `copy` for CUDA, implemented by pten::Copy, using
// the ALL_DTYPE form of the registration macro. The `{}` body is empty, so no
// per-argument settings are applied.
// NOTE(review): `ANY` is presumably the layout argument; confirm against
// paddle/pten/core/kernel_registry.h.
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {}
......@@ -95,12 +95,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationXPU);
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten_contiguous_range",
PT_REGISTER_KERNEL(flatten,
XPU,
ANY,
pten::Flatten,
......@@ -112,7 +107,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range",
int,
int64_t) {}
PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
PT_REGISTER_KERNEL(flatten_mid,
XPU,
ANY,
pten::FlattenWithXShape,
......@@ -124,9 +119,4 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int,
int64_t) {}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
XPU,
ANY,
pten::ReshapeFromVectorVal) {}
// Registers the kernel named `reshape` for XPU, implemented by
// pten::ReshapeFromVectorVal. The `{}` body is empty, so no per-argument
// settings are applied.
// NOTE(review): `ANY` is presumably the layout argument; confirm against
// paddle/pten/core/kernel_registry.h.
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {}
......@@ -76,7 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(UtilsXPU);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", XPU, ANY, pten::Copy) {}
// Registers the kernel named `copy` for XPU, implemented by pten::Copy, using
// the ALL_DTYPE form of the registration macro. The `{}` body is empty, so no
// per-argument settings are applied.
// NOTE(review): `ANY` is presumably the layout argument; confirm against
// paddle/pten/core/kernel_registry.h.
PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {}
......@@ -21,12 +21,6 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace paddle {
namespace tests {
......
......@@ -345,6 +345,7 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/kernel_declare.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -353,22 +354,6 @@ def source_include(header_file_path):
"""
# Returns a fixed block of C++ source text containing PT_DECLARE_MODULE lines:
# the Creation/Linalg/Manipulation/Math CPU modules unconditionally, and the
# matching CUDA modules inside a PADDLE_WITH_CUDA / PADDLE_WITH_HIP guard.
# generate_api writes this text into the generated source file.
# The string is emitted verbatim, so its content is code, not commentary.
def module_declare():
return """
PT_DECLARE_MODULE(CreationCPU);
PT_DECLARE_MODULE(LinalgCPU);
PT_DECLARE_MODULE(ManipulationCPU);
PT_DECLARE_MODULE(MathCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(CreationCUDA);
PT_DECLARE_MODULE(LinalgCUDA);
PT_DECLARE_MODULE(ManipulationCUDA);
PT_DECLARE_MODULE(MathCUDA);
#endif
"""
def api_register():
return """
PT_REGISTER_API(Creation);
......@@ -405,7 +390,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
include_header_file = "paddle/pten/api/include/api.h"
source_file.write(source_include(include_header_file))
source_file.write(module_declare())
source_file.write(namespace[0])
for api in apis:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册