未验证 提交 c9da845f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Polish kernel register macro design (#38078)

* polish register macro

* resolve compile failed

* revert needless change

* revert eager related change

* revert eager related change

* change register macro name

* polish details
上级 206a33b3
......@@ -20,18 +20,18 @@ limitations under the License. */
// the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT);
#endif
......@@ -25,14 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
PT_DECLARE_KERNEL(copy, CPU);
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, CUDA);
PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT);
#endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU);
PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif
namespace paddle {
......
......@@ -37,7 +37,6 @@ namespace experimental {
* in the future
*/
enum class Backend : uint8_t {
// kernel backend cannot be undefined
UNDEFINED = 0,
// basic kernel backend
......@@ -54,6 +53,42 @@ enum class Backend : uint8_t {
// end of backend types
NUM_BACKENDS,
/**
* [ Why we need ALL in basic kernel key member? ]
*
* For Tensor, ALL represents an illegal Backend, but for Kernel, some
* kernels may be device-independent by nature, such as reshape; and some
* kernels are also device-independent when implemented based on the
* primitive API.
*
* In this case, we need to provide a more concise registration method:
* instead of registering the kernels for each device with almost
* repetitive code, we need one registration that covers all situations,
* so we provide the ALL field and register the kernel with it in a
* single statement.
*
* Of course, we have also considered solving this problem through
* differently named macros, for example, if we define
*
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND
*
* Based on this design pattern, the dtype and layout have the same
* requirements, which would force us to define a series of macros
*
* PT_REGISTER_KERNEL_FOR_ALL_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE
*
* It makes the system of registering macros more complicated, we think
* this is not a simple design, so we still adopt the design of providing
* the ALL field.
*
* Note: ALL_BACKEND is only used for Kernel registration and selection
*/
ALL_BACKEND = UNDEFINED,
};
inline std::ostream& operator<<(std::ostream& os, Backend backend) {
......
......@@ -45,7 +45,9 @@ enum class DataType {
FLOAT64,
COMPLEX64,
COMPLEX128,
NUM_DATA_TYPES
NUM_DATA_TYPES,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_DTYPE = UNDEFINED,
};
inline size_t SizeOf(DataType data_type) {
......
......@@ -20,11 +20,14 @@ namespace experimental {
enum class DataLayout {
UNDEFINED = 0,
ANY,
// TODO(chenweihang): keep ANY for compatibility, remove it later
ANY = UNDEFINED,
NHWC,
NCHW,
MKLDNN,
NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
};
inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
......@@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
case DataLayout::UNDEFINED:
os << "Undefined";
break;
case DataLayout::ANY:
os << "Any";
break;
case DataLayout::NHWC:
os << "NHWC";
break;
......
此差异已折叠。
......@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(full_like,
CPU,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
......@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full,
CPU,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
......
......@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot,
CPU,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
......@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot,
complex64,
complex128) {}
PT_REGISTER_KERNEL(
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
PT_REGISTER_KERNEL(matmul,
CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
......@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
double,
......@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CPU,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
......@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
PT_REGISTER_KERNEL(cast,
CPU,
ANY,
ALL_LAYOUT,
pten::Cast,
float,
double,
......@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CPU,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
......@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ANY,
ALL_LAYOUT,
pten::Scale,
float,
double,
......@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {}
PT_REGISTER_KERNEL(add,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseAdd,
float,
double,
......@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add,
complex128) {}
PT_REGISTER_KERNEL(subtract,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseSub,
float,
double,
......@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {}
PT_REGISTER_KERNEL(divide,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseDiv,
float,
double,
......@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {}
PT_REGISTER_KERNEL(multiply,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseMul,
float,
double,
......@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {}
PT_REGISTER_KERNEL(sum,
CPU,
ANY,
ALL_LAYOUT,
pten::Sum,
bool,
float,
......
......@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
......@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx,
PT_REGISTER_KERNEL(full_like,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
......@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
......
......@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot,
CUDA,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
......@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot,
PT_REGISTER_KERNEL(matmul,
CUDA,
ANY,
ALL_LAYOUT,
pten::Matmul,
float,
double,
......
......@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16;
PT_REGISTER_KERNEL(flatten,
CUDA,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
float16,
......@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
......@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \
CUDA, \
ANY, \
ALL_LAYOUT, \
pten::Cast, \
float, \
double, \
......@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CUDA,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CUDA, ANY, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CUDA, ANY, pten::ReshapeWithXShape, ALL_DTYPE) {}
......@@ -115,11 +115,12 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) {
}
PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CUDA,
ANY,
ALL_LAYOUT,
pten::Scale,
float,
double,
......@@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {}
PT_REGISTER_KERNEL(add,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseAdd,
float,
double,
......@@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add,
complex128) {}
PT_REGISTER_KERNEL(subtract,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseSub,
float,
double,
......@@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {}
PT_REGISTER_KERNEL(divide,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseDiv,
float,
double,
......@@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {}
PT_REGISTER_KERNEL(multiply,
CUDA,
ANY,
ALL_LAYOUT,
pten::ElementwiseMul,
float,
double,
......@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {}
PT_REGISTER_KERNEL(sum,
CUDA,
ANY,
ALL_LAYOUT,
pten::Sum,
bool,
float,
......
......@@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
}
} // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CUDA, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
......@@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten,
XPU,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
paddle::platform::float16,
......@@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten,
PT_REGISTER_KERNEL(flatten_with_xshape,
XPU,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
paddle::platform::float16,
......@@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
int,
int64_t) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
......@@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, XPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
......@@ -28,7 +28,7 @@ TEST(DataLayout, OStream) {
EXPECT_EQ(oss.str(), "Undefined");
oss.str("");
oss << pten::DataLayout::ANY;
EXPECT_EQ(oss.str(), "Any");
EXPECT_EQ(oss.str(), "Undefined");
oss.str("");
oss << pten::DataLayout::NHWC;
EXPECT_EQ(oss.str(), "NHWC");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册