Unverified commit b199ba85, authored by Chen Weihang, committed by GitHub

[PTen] Refine Kernel Registrar Writing (#37977)

* refine the kernel register impl

* fix cmake and symbol error

* remove overload macro

* polish details
Parent dfed4a63
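
In short, the refactor replaces string-literal kernel names ("sign", "reshape.mulhost") with plain identifiers (sign, reshape_mulhost), renames the typeless overload PT_REGISTER_KERNEL_WITH_NO_TYPE to PT_REGISTER_KERNEL_ALL_DTYPE, and swaps per-file module declarations for per-kernel ones. A minimal before/after sketch assembled from the diff below (the kernels named are Paddle's own; only the juxtaposition is editorial):

    // Before: string name, per-file module symbol.
    PT_REGISTER_MODULE(MathCPU);
    PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {}
    // ...and in any file that needs the kernels linked in:
    PT_DECLARE_MODULE(MathCPU);

    // After: identifier name, per-kernel declaration.
    PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
    // ...and in any file that needs the kernel linked in:
    PT_DECLARE_KERNEL(sign, CPU);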
......@@ -555,10 +555,10 @@ class Reshape2Op : public ReshapeOp {
const framework::ExecutionContext &ctx) const override {
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
return framework::KernelSignature("reshape.mulhost", {"X", "ShapeTensor"},
return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"},
{}, {"Out"});
} else if (ctx.HasInput("Shape")) {
return framework::KernelSignature("reshape.host", {"X", "Shape"}, {},
return framework::KernelSignature("reshape_host", {"X", "Shape"}, {},
{"Out"});
} else {
return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/kernel_registry.h"
// TODO(chenweihang): After the kernels are split into single files, the
// kernel declaration statements will be automatically generated from the
// kernel file names, and this header file will be removed
PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
#endif
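
These per-kernel declarations guard against linker dead-stripping: when kernels are registered by global initializers inside a static library, nothing else references their object files, so the linker may drop them (plausibly the "symbol error" the commit message mentions). A hedged sketch of the touch-symbol idiom that declare/register macro pairs like these conventionally expand to — the expansion below is illustrative, not Paddle's actual macro body:

    // Registering TU: PT_REGISTER_KERNEL(sign, CPU, ...) defines, alongside
    // the registrar object, a function that other TUs can reference.
    int TouchKernelSymbolFor_sign_CPU() { return 0; }

    // Declaring TU: PT_DECLARE_KERNEL(sign, CPU) references that function,
    // forcing the linker to keep the object file whose static initializer
    // performs the registration.
    extern int TouchKernelSymbolFor_sign_CPU();
    [[maybe_unused]] static int touch_sign_CPU = TouchKernelSymbolFor_sign_CPU();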
......@@ -25,10 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
PT_DECLARE_MODULE(UtilsCPU);
PT_DECLARE_KERNEL(copy, CPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(UtilsCUDA);
PT_DECLARE_KERNEL(copy, CUDA);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU);
#endif
namespace paddle {
......
This diff is collapsed.
......@@ -61,9 +61,7 @@ void FillConstant(const CPUContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(CreationCPU);
PT_REGISTER_KERNEL("full_like",
PT_REGISTER_KERNEL(full_like,
CPU,
ANY,
pten::FillAnyLike,
......@@ -74,7 +72,7 @@ PT_REGISTER_KERNEL("full_like",
bool,
paddle::platform::float16) {}
PT_REGISTER_KERNEL("full",
PT_REGISTER_KERNEL(full,
CPU,
ANY,
pten::FillConstant,
......
......@@ -70,12 +70,10 @@ void Matmul(const CPUContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(LinalgCPU);
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("dot",
PT_REGISTER_KERNEL(dot,
CPU,
ANY,
pten::Dot,
......@@ -87,5 +85,4 @@ PT_REGISTER_KERNEL("dot",
complex128) {}
PT_REGISTER_KERNEL(
"matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) {
}
matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
......@@ -130,12 +130,9 @@ void Cast(const CPUContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCPU);
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten",
PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
pten::Flatten,
......@@ -145,8 +142,7 @@ PT_REGISTER_KERNEL("flatten",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("flatten.mid",
PT_REGISTER_KERNEL(flatten_mid,
CPU,
ANY,
pten::FlattenWithXShape,
......@@ -156,7 +152,8 @@ PT_REGISTER_KERNEL("flatten.mid",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("cast",
PT_REGISTER_KERNEL(cast,
CPU,
ANY,
pten::Cast,
......@@ -174,39 +171,30 @@ PT_REGISTER_KERNEL("cast",
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape",
CPU,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
CPU,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host", CPU, ANY, pten::ReshapeFromDT) {
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
CPU,
ANY,
pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CPU,
ANY,
pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
CPU,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
......
......@@ -106,18 +106,14 @@ DEFINE_CPU_ELEMENTWISE_OP(Mul)
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCPU);
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL("sign", CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL("mean", CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale",
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ANY,
pten::Scale,
......@@ -129,8 +125,7 @@ PT_REGISTER_KERNEL("scale",
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("add",
PT_REGISTER_KERNEL(add,
CPU,
ANY,
pten::ElementwiseAdd,
......@@ -140,7 +135,7 @@ PT_REGISTER_KERNEL("add",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("subtract",
PT_REGISTER_KERNEL(subtract,
CPU,
ANY,
pten::ElementwiseSub,
......@@ -150,7 +145,7 @@ PT_REGISTER_KERNEL("subtract",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("divide",
PT_REGISTER_KERNEL(divide,
CPU,
ANY,
pten::ElementwiseDiv,
......@@ -160,7 +155,7 @@ PT_REGISTER_KERNEL("divide",
int64_t,
complex64,
complex128) {}
PT_REGISTER_KERNEL("multiply",
PT_REGISTER_KERNEL(multiply,
CPU,
ANY,
pten::ElementwiseMul,
......@@ -171,8 +166,7 @@ PT_REGISTER_KERNEL("multiply",
bool,
complex64,
complex128) {}
PT_REGISTER_KERNEL("sum",
PT_REGISTER_KERNEL(sum,
CPU,
ANY,
pten::Sum,
......
......@@ -57,7 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(UtilsCPU);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CPU, ANY, pten::Copy) {}
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
......@@ -62,9 +62,7 @@ void FillConstant(const CUDAContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(CreationCUDA);
PT_REGISTER_KERNEL("full_like",
PT_REGISTER_KERNEL(full_like,
CUDA,
ANY,
pten::FillAnyLike,
......@@ -75,7 +73,7 @@ PT_REGISTER_KERNEL("full_like",
bool,
paddle::platform::float16) {}
PT_REGISTER_KERNEL("full",
PT_REGISTER_KERNEL(full,
CUDA,
ANY,
pten::FillConstant,
......
......@@ -54,13 +54,11 @@ void Matmul(const CUDAContext& dev_ctx,
} // namespace pten
PT_REGISTER_MODULE(LinalgCUDA);
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("dot",
PT_REGISTER_KERNEL(dot,
CUDA,
ANY,
pten::Dot,
......@@ -71,7 +69,7 @@ PT_REGISTER_KERNEL("dot",
complex64,
complex128) {}
PT_REGISTER_KERNEL("matmul_v2",
PT_REGISTER_KERNEL(matmul_v2,
CUDA,
ANY,
pten::Matmul,
......
......@@ -129,13 +129,9 @@ void Cast(const CUDAContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCUDA);
using float16 = paddle::platform::float16;
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten",
PT_REGISTER_KERNEL(flatten,
CUDA,
ANY,
pten::Flatten,
......@@ -146,8 +142,7 @@ PT_REGISTER_KERNEL("flatten",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("flatten.mid",
PT_REGISTER_KERNEL(flatten_mid,
CUDA,
ANY,
pten::FlattenWithXShape,
......@@ -159,7 +154,7 @@ PT_REGISTER_KERNEL("flatten.mid",
int64_t) {}
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL("cast", \
PT_REGISTER_KERNEL(cast, \
CUDA, \
ANY, \
pten::Cast, \
......@@ -184,41 +179,30 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape",
CUDA,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mid",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
CUDA,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host",
CUDA,
ANY,
pten::ReshapeFromDT) {
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.host.mid",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
CUDA,
ANY,
pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CUDA,
ANY,
pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape.mulhost.mid",
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
CUDA,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
......
......@@ -111,16 +111,13 @@ void Sum(const CUDAContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCUDA);
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("sign", CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL("mean", CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL("scale",
PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CUDA,
ANY,
pten::Scale,
......@@ -132,7 +129,7 @@ PT_REGISTER_KERNEL("scale",
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("add",
PT_REGISTER_KERNEL(add,
CUDA,
ANY,
pten::ElementwiseAdd,
......@@ -143,7 +140,7 @@ PT_REGISTER_KERNEL("add",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("subtract",
PT_REGISTER_KERNEL(subtract,
CUDA,
ANY,
pten::ElementwiseSub,
......@@ -154,7 +151,7 @@ PT_REGISTER_KERNEL("subtract",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("divide",
PT_REGISTER_KERNEL(divide,
CUDA,
ANY,
pten::ElementwiseDiv,
......@@ -165,7 +162,7 @@ PT_REGISTER_KERNEL("divide",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("multiply",
PT_REGISTER_KERNEL(multiply,
CUDA,
ANY,
pten::ElementwiseMul,
......@@ -177,7 +174,7 @@ PT_REGISTER_KERNEL("multiply",
float16,
complex64,
complex128) {}
PT_REGISTER_KERNEL("sum",
PT_REGISTER_KERNEL(sum,
CUDA,
ANY,
pten::Sum,
......
......@@ -234,7 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
}
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(UtilsCUDA);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CUDA, ANY, pten::Copy) {}
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {}
......@@ -95,12 +95,7 @@ void ReshapeFromVectorDT(const XPUContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationXPU);
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL("flatten_contiguous_range",
PT_REGISTER_KERNEL(flatten,
XPU,
ANY,
pten::Flatten,
......@@ -112,7 +107,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range",
int,
int64_t) {}
PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
PT_REGISTER_KERNEL(flatten_mid,
XPU,
ANY,
pten::FlattenWithXShape,
......@@ -124,9 +119,4 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int,
int64_t) {}
// TODO(yuanrisheng): "reshape2" is compatible with old kernel
// architecture, kernel_name should be "reshape".
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
XPU,
ANY,
pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {}
......@@ -76,7 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(UtilsXPU);
PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", XPU, ANY, pten::Copy) {}
PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {}
......@@ -21,12 +21,6 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace paddle {
namespace tests {
......
......@@ -345,6 +345,7 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/kernel_declare.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -353,22 +354,6 @@ def source_include(header_file_path):
"""
def module_declare():
return """
PT_DECLARE_MODULE(CreationCPU);
PT_DECLARE_MODULE(LinalgCPU);
PT_DECLARE_MODULE(ManipulationCPU);
PT_DECLARE_MODULE(MathCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(CreationCUDA);
PT_DECLARE_MODULE(LinalgCUDA);
PT_DECLARE_MODULE(ManipulationCUDA);
PT_DECLARE_MODULE(MathCUDA);
#endif
"""
def api_register():
return """
PT_REGISTER_API(Creation);
......@@ -405,7 +390,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
include_header_file = "paddle/pten/api/include/api.h"
source_file.write(source_include(include_header_file))
source_file.write(module_declare())
source_file.write(namespace[0])
for api in apis:
......