Unverified commit c9da845f authored by Chen Weihang, committed by GitHub

[PTen] Polish kernel register macro design (#38078)

* polish register macro

* resolve compile failed

* revert needless change

* revert eager related change

* revert eager related change

* change register macro name

* polish details
Parent 206a33b3
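In short, this PR changes the declaration and registration macro usage as follows (the before/after lines are taken verbatim from the diff below and only compile inside the paddle/pten source tree): the layout field ANY is replaced by the explicit ALL_LAYOUT, PT_DECLARE_KERNEL gains a layout argument, and PT_REGISTER_KERNEL_ALL_DTYPE is renamed to PT_REGISTER_NO_TEMPLATE_KERNEL with an explicit ALL_DTYPE argument.

// Before this PR
PT_DECLARE_KERNEL(sign, CPU);
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}

// After this PR
PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);
PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}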
...@@ -20,18 +20,18 @@ limitations under the License. */ ...@@ -20,18 +20,18 @@ limitations under the License. */
// the kernel declare statement is automatically generated according to the // the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed // file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(full_like, CPU); PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU); PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU); PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU); PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA); PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CUDA); PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CUDA); PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CUDA); PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT);
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU); PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT);
#endif #endif
...@@ -25,14 +25,14 @@ limitations under the License. */ ...@@ -25,14 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/infermeta.h"
PT_DECLARE_KERNEL(copy, CPU); PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, CUDA); PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT);
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU); PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif #endif
namespace paddle { namespace paddle {
......
...@@ -37,7 +37,6 @@ namespace experimental { ...@@ -37,7 +37,6 @@ namespace experimental {
* in the future * in the future
*/ */
enum class Backend : uint8_t { enum class Backend : uint8_t {
// kernel backend cannot be undefined
UNDEFINED = 0, UNDEFINED = 0,
// basic kernel backend // basic kernel backend
...@@ -54,6 +53,42 @@ enum class Backend : uint8_t { ...@@ -54,6 +53,42 @@ enum class Backend : uint8_t {
// end of backend types // end of backend types
NUM_BACKENDS, NUM_BACKENDS,
/**
* [ Why we need ALL in basic kernel key member? ]
*
* For Tensor, ALL represents an illegal Backend, but for Kernel, some
* kernels are device-independent by nature, such as reshape, and some
* kernels also become device-independent when they are implemented on top
* of the primitive API.
*
* In these cases we need a more concise registration method: instead of
* registering the kernel for each device with almost repetitive code, one
* registration should cover all situations. So we provide the ALL field
* and register the kernel with it in a single statement (see the sketch
* after this enum).
*
* Of course, we also considered solving this problem with differently
* named macros, for example by defining
*
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND
*
* Under that design pattern, dtype and layout have the same requirement,
* which forces us to define a whole series of macros:
*
* PT_REGISTER_KERNEL_FOR_ALL_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE
*
* This makes the system of registration macros more complicated; we do not
* consider it a simple design, so we still adopt the design of providing
* the ALL field.
*
* Note: ALL_BACKEND is only used for Kernel registration and selection.
*/
ALL_BACKEND = UNDEFINED,
}; };
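To make the new ALL_BACKEND member concrete, here is a minimal, self-contained sketch (a mock enum, not paddle's real backend.h) showing that ALL_BACKEND is just another name for UNDEFINED, which is why it can only appear in kernel registration and selection, never as a concrete Tensor backend:

#include <cstdint>
#include <iostream>

// Mock of the aliasing pattern used above: the ALL_* member reuses the
// underlying value of UNDEFINED instead of adding a new enumerator.
enum class Backend : uint8_t {
  UNDEFINED = 0,
  CPU,
  CUDA,
  ALL_BACKEND = UNDEFINED,  // alias, not a distinct value
};

int main() {
  // The equality is a constant expression, so it holds at compile time
  // as well as at run time.
  static_assert(Backend::ALL_BACKEND == Backend::UNDEFINED,
                "ALL_BACKEND aliases UNDEFINED");
  std::cout << static_cast<int>(Backend::ALL_BACKEND) << std::endl;  // prints 0
  return 0;
}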
inline std::ostream& operator<<(std::ostream& os, Backend backend) { inline std::ostream& operator<<(std::ostream& os, Backend backend) {
......
...@@ -45,7 +45,9 @@ enum class DataType { ...@@ -45,7 +45,9 @@ enum class DataType {
FLOAT64, FLOAT64,
COMPLEX64, COMPLEX64,
COMPLEX128, COMPLEX128,
NUM_DATA_TYPES NUM_DATA_TYPES,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_DTYPE = UNDEFINED,
}; };
inline size_t SizeOf(DataType data_type) { inline size_t SizeOf(DataType data_type) {
......
...@@ -20,11 +20,14 @@ namespace experimental { ...@@ -20,11 +20,14 @@ namespace experimental {
enum class DataLayout { enum class DataLayout {
UNDEFINED = 0, UNDEFINED = 0,
ANY, // TODO(chenweihang): keep ANY for compatibility, remove it later
ANY = UNDEFINED,
NHWC, NHWC,
NCHW, NCHW,
MKLDNN, MKLDNN,
NUM_DATA_LAYOUTS, NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
}; };
inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
...@@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) { ...@@ -32,9 +35,6 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
case DataLayout::UNDEFINED: case DataLayout::UNDEFINED:
os << "Undefined"; os << "Undefined";
break; break;
case DataLayout::ANY:
os << "Any";
break;
case DataLayout::NHWC: case DataLayout::NHWC:
os << "NHWC"; os << "NHWC";
break; break;
......
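The removal of the separate "case DataLayout::ANY:" branch above is forced by the aliasing: once ANY shares the value of UNDEFINED, a switch cannot list both labels, and printing ANY now yields "Undefined" (which is what the ostream test near the end of this diff expects). A small standalone sketch, using a mock enum rather than paddle's actual header:

#include <iostream>

// Mock enum mirroring the aliasing introduced above (not paddle's header).
enum class DataLayout { UNDEFINED = 0, ANY = UNDEFINED, NHWC, NCHW };

// Mock printer: adding "case DataLayout::ANY:" here would be a duplicate
// case value and fail to compile, so UNDEFINED covers both names.
const char* ToString(DataLayout layout) {
  switch (layout) {
    case DataLayout::UNDEFINED:  // also reached for DataLayout::ANY
      return "Undefined";
    case DataLayout::NHWC:
      return "NHWC";
    case DataLayout::NCHW:
      return "NCHW";
  }
  return "Unknown";
}

int main() {
  // Matches the updated test expectation: DataLayout::ANY prints "Undefined".
  std::cout << ToString(DataLayout::ANY) << std::endl;
  return 0;
}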
This diff is collapsed.
...@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx, ...@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::FillAnyLike, pten::FillAnyLike,
float, float,
double, double,
...@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like, ...@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full, PT_REGISTER_KERNEL(full,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::FillConstant, pten::FillConstant,
float, float,
double, double,
......
...@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot, PT_REGISTER_KERNEL(dot,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Dot, pten::Dot,
float, float,
double, double,
...@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot, ...@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL( PT_REGISTER_KERNEL(matmul,
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
...@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx, ...@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Flatten, pten::Flatten,
float, float,
double, double,
...@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten, ...@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
double, double,
...@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape, ...@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
PT_REGISTER_KERNEL(cast, PT_REGISTER_KERNEL(cast,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Cast, pten::Cast,
float, float,
double, double,
...@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast, ...@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
} }
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {} PT_REGISTER_NO_TEMPLATE_KERNEL(
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
CPU, PT_REGISTER_NO_TEMPLATE_KERNEL(
ANY, reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
pten::ReshapeWithXShape) {}
...@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16; // using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {} PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {} PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale, PT_REGISTER_KERNEL(scale,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Scale, pten::Scale,
float, float,
double, double,
...@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale, ...@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(add, PT_REGISTER_KERNEL(add,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseAdd, pten::ElementwiseAdd,
float, float,
double, double,
...@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add, ...@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseSub, pten::ElementwiseSub,
float, float,
double, double,
...@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract, ...@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(divide, PT_REGISTER_KERNEL(divide,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseDiv, pten::ElementwiseDiv,
float, float,
double, double,
...@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide, ...@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::ElementwiseMul, pten::ElementwiseMul,
float, float,
double, double,
...@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply, ...@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(sum, PT_REGISTER_KERNEL(sum,
CPU, CPU,
ANY, ALL_LAYOUT,
pten::Sum, pten::Sum,
bool, bool,
float, float,
......
...@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx, ...@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {} PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
...@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx, ...@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx,
PT_REGISTER_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::FillAnyLike, pten::FillAnyLike,
float, float,
double, double,
...@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like, ...@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like,
PT_REGISTER_KERNEL(full, PT_REGISTER_KERNEL(full,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::FillConstant, pten::FillConstant,
float, float,
double, double,
......
...@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(dot, PT_REGISTER_KERNEL(dot,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Dot, pten::Dot,
float, float,
double, double,
...@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot, ...@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot,
PT_REGISTER_KERNEL(matmul, PT_REGISTER_KERNEL(matmul,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Matmul, pten::Matmul,
float, float,
double, double,
......
...@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16; ...@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16;
PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Flatten, pten::Flatten,
float, float,
float16, float16,
...@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten, ...@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
double, double,
...@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape, ...@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \ PT_REGISTER_KERNEL(cast, \
CUDA, \ CUDA, \
ANY, \ ALL_LAYOUT, \
pten::Cast, \ pten::Cast, \
float, \ float, \
double, \ double, \
...@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) ...@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif #endif
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {} PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CUDA, ANY, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, PT_REGISTER_NO_TEMPLATE_KERNEL(
CUDA, reshape_with_xshape, CUDA, ANY, pten::ReshapeWithXShape, ALL_DTYPE) {}
ANY,
pten::ReshapeWithXShape) {}
...@@ -115,11 +115,12 @@ using float16 = paddle::platform::float16; ...@@ -115,11 +115,12 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(sign, CUDA, ANY, pten::Sign, float, double, float16) {} PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) {
PT_REGISTER_KERNEL(mean, CUDA, ANY, pten::Mean, float, double, bool) {} }
PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale, PT_REGISTER_KERNEL(scale,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Scale, pten::Scale,
float, float,
double, double,
...@@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale, ...@@ -131,7 +132,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL(add, PT_REGISTER_KERNEL(add,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseAdd, pten::ElementwiseAdd,
float, float,
double, double,
...@@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add, ...@@ -142,7 +143,7 @@ PT_REGISTER_KERNEL(add,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseSub, pten::ElementwiseSub,
float, float,
double, double,
...@@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract, ...@@ -153,7 +154,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(divide, PT_REGISTER_KERNEL(divide,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseDiv, pten::ElementwiseDiv,
float, float,
double, double,
...@@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide, ...@@ -164,7 +165,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::ElementwiseMul, pten::ElementwiseMul,
float, float,
double, double,
...@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply, ...@@ -176,7 +177,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(sum, PT_REGISTER_KERNEL(sum,
CUDA, CUDA,
ANY, ALL_LAYOUT,
pten::Sum, pten::Sum,
bool, bool,
float, float,
......
...@@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx, ...@@ -234,4 +234,4 @@ void Copy(const CUDAContext& dev_ctx,
} }
} // namespace pten } // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, CUDA, ANY, pten::Copy) {} PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CUDA, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
...@@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx, ...@@ -78,7 +78,7 @@ void ReshapeWithXShape(const XPUContext& dev_ctx,
PT_REGISTER_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
XPU, XPU,
ANY, ALL_LAYOUT,
pten::Flatten, pten::Flatten,
float, float,
paddle::platform::float16, paddle::platform::float16,
...@@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten, ...@@ -90,7 +90,7 @@ PT_REGISTER_KERNEL(flatten,
PT_REGISTER_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
XPU, XPU,
ANY, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
float, float,
paddle::platform::float16, paddle::platform::float16,
...@@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape, ...@@ -100,4 +100,5 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {} PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
...@@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx, ...@@ -76,4 +76,4 @@ void Copy(const XPUDeviceContext& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_KERNEL_ALL_DTYPE(copy, XPU, ANY, pten::Copy) {} PT_REGISTER_NO_TEMPLATE_KERNEL(copy, XPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
...@@ -28,7 +28,7 @@ TEST(DataLayout, OStream) { ...@@ -28,7 +28,7 @@ TEST(DataLayout, OStream) {
EXPECT_EQ(oss.str(), "Undefined"); EXPECT_EQ(oss.str(), "Undefined");
oss.str(""); oss.str("");
oss << pten::DataLayout::ANY; oss << pten::DataLayout::ANY;
EXPECT_EQ(oss.str(), "Any"); EXPECT_EQ(oss.str(), "Undefined");
oss.str(""); oss.str("");
oss << pten::DataLayout::NHWC; oss << pten::DataLayout::NHWC;
EXPECT_EQ(oss.str(), "NHWC"); EXPECT_EQ(oss.str(), "NHWC");
......