Unverified commit 158bf13f, authored by Chen Weihang, committed by GitHub

[PTen] Rename kernel register macro (#38861)

* rename register macro

* fix erroneous changes

* fix format error
Parent dccdc719
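The rename is mechanical: every registration site keeps its argument list and only the macro name changes. A minimal before/after sketch, taken from the sign kernel registration that appears later in this diff:

    // Before: context-suffixed macro name
    PT_REGISTER_CTX_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {}

    // After: same arguments, renamed macro
    PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {}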
@@ -16,12 +16,12 @@
function(kernel_declare TARGET_LIST)
foreach(kernel_path ${TARGET_LIST})
file(READ ${kernel_path} kernel_impl)
-# TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL
+# TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
-string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
+string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
if (NOT first_registry STREQUAL "")
# parse the first kernel name
string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}")
string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
string(REPLACE "," "" kernel_name "${kernel_name}")
string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
......
This diff is collapsed.
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
} // namespace pten
-PT_REGISTER_CTX_KERNEL(cast,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::CastKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       int16_t,
-                       bool,
-                       uint8_t,
-                       paddle::platform::float16,
-                       paddle::platform::bfloat16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {
+PT_REGISTER_KERNEL(cast,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::CastKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   int16_t,
+                   bool,
+                   uint8_t,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
@@ -21,13 +21,13 @@
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
-PT_REGISTER_CTX_KERNEL(conj,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::ConjKernel,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>,
-                       float,
-                       double,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(conj,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::ConjKernel,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
@@ -20,13 +20,13 @@
#include "paddle/fluid/platform/complex.h"
-PT_REGISTER_CTX_KERNEL(dot_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::DotGradKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(dot_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DotGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
-PT_REGISTER_CTX_KERNEL(dot,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::DotKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
+PT_REGISTER_KERNEL(dot,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DotKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
@@ -18,29 +18,29 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
-PT_REGISTER_CTX_KERNEL(full,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FullKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::bfloat16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(full,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FullKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(full_like,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FullLikeKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16) {}
+PT_REGISTER_KERNEL(full_like,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FullLikeKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16) {}
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
-PT_REGISTER_CTX_KERNEL(
+PT_REGISTER_KERNEL(
mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
-PT_REGISTER_CTX_KERNEL(add,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::AddKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(subtract,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::SubtractKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(divide,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::DivideKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(multiply,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MultiplyKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       bool,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(sum,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::SumKernel,
-                       bool,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {
+PT_REGISTER_KERNEL(add,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::AddKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(subtract,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SubtractKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(divide,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DivideKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(multiply,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MultiplyKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(sum,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SumKernel,
+                   bool,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
@@ -19,29 +19,29 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
-PT_REGISTER_CTX_KERNEL(matmul_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulGradKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(matmul_double_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulDoubleGradKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulTripleGradKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulGradKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulDoubleGradKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_triple_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulTripleGradKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
@@ -20,11 +20,11 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
-PT_REGISTER_CTX_KERNEL(matmul,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
} // namespace pten
-PT_REGISTER_CTX_KERNEL(scale,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::ScaleKernel,
-                       float,
-                       double,
-                       paddle::platform::bfloat16,
-                       uint8_t,
-                       int8_t,
-                       int16_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(scale,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::ScaleKernel,
+                   float,
+                   double,
+                   paddle::platform::bfloat16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
-PT_REGISTER_CTX_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {
-}
+PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {}
@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
} // namespace pten
-PT_REGISTER_CTX_KERNEL(empty,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::EmptyKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::bfloat16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(empty,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::EmptyKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(empty_like,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::EmptyLikeKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::bfloat16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(empty_like,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::EmptyLikeKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PT_REGISTER_CTX_KERNEL(empty,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::EmptyKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(empty,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::EmptyKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(empty_like,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::EmptyLikeKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(empty_like,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::EmptyLikeKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
#endif
@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
} // namespace pten
-PT_REGISTER_CTX_KERNEL(flatten_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FlattenGradKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FlattenGradKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int,
+                   int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PT_REGISTER_CTX_KERNEL(flatten_grad,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::FlattenGradKernel,
-                       float,
-                       paddle::platform::float16,
-                       double,
-                       uint8_t,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::FlattenGradKernel,
+                   float,
+                   paddle::platform::float16,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int,
+                   int64_t) {}
#endif
#ifdef PADDLE_WITH_XPU
-PT_REGISTER_CTX_KERNEL(flatten_grad,
-                       XPU,
-                       ALL_LAYOUT,
-                       pten::FlattenGradKernel,
-                       float,
-                       paddle::platform::float16,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   pten::FlattenGradKernel,
+                   float,
+                   paddle::platform::float16,
+                   int8_t,
+                   int,
+                   int64_t) {}
#endif
@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
} // namespace pten
-PT_REGISTER_CTX_KERNEL(flatten,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FlattenKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FlattenKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int,
+                   int64_t) {}
-PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FlattenWithXShape,
-                       float,
-                       double,
-                       uint8_t,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten_with_xshape,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FlattenWithXShape,
+                   float,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int,
+                   int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PT_REGISTER_CTX_KERNEL(flatten,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::FlattenKernel,
-                       float,
-                       paddle::platform::float16,
-                       double,
-                       uint8_t,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::FlattenKernel,
+                   float,
+                   paddle::platform::float16,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int,
+                   int64_t) {}
-PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::FlattenWithXShape,
-                       float,
-                       paddle::platform::float16,
-                       double,
-                       uint8_t,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten_with_xshape,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::FlattenWithXShape,
+                   float,
+                   paddle::platform::float16,
+                   double,
+                   uint8_t,
+                   int8_t,
+                   int,
+                   int64_t) {}
#endif
#ifdef PADDLE_WITH_XPU
-PT_REGISTER_CTX_KERNEL(flatten,
-                       XPU,
-                       ALL_LAYOUT,
-                       pten::FlattenKernel,
-                       float,
-                       paddle::platform::float16,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten,
+                   XPU,
+                   ALL_LAYOUT,
+                   pten::FlattenKernel,
+                   float,
+                   paddle::platform::float16,
+                   int8_t,
+                   int,
+                   int64_t) {}
-PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
-                       XPU,
-                       ALL_LAYOUT,
-                       pten::FlattenWithXShape,
-                       float,
-                       paddle::platform::float16,
-                       int8_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(flatten_with_xshape,
+                   XPU,
+                   ALL_LAYOUT,
+                   pten::FlattenWithXShape,
+                   float,
+                   paddle::platform::float16,
+                   int8_t,
+                   int,
+                   int64_t) {}
#endif
@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
} // namespace pten
-#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)      \
-  PT_REGISTER_CTX_KERNEL(cast,                               \
-                         GPU,                                \
-                         ALL_LAYOUT,                         \
-                         pten::CastKernel,                   \
-                         float,                              \
-                         double,                             \
-                         int,                                \
-                         int64_t,                            \
-                         int16_t,                            \
-                         bool,                               \
-                         uint8_t,                            \
-                         paddle::platform::float16,          \
-                         paddle::platform::complex<float>,   \
-                         paddle::platform::complex<double>,  \
-                         ##__VA_ARGS__) {                    \
-    kernel->OutputAt(0).SetDataType(                         \
-        paddle::experimental::DataType::UNDEFINED);          \
+#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)  \
+  PT_REGISTER_KERNEL(cast,                               \
+                     GPU,                                \
+                     ALL_LAYOUT,                         \
+                     pten::CastKernel,                   \
+                     float,                              \
+                     double,                             \
+                     int,                                \
+                     int64_t,                            \
+                     int16_t,                            \
+                     bool,                               \
+                     uint8_t,                            \
+                     paddle::platform::float16,          \
+                     paddle::platform::complex<float>,   \
+                     paddle::platform::complex<double>,  \
+                     ##__VA_ARGS__) {                    \
+    kernel->OutputAt(0).SetDataType(                     \
+        paddle::experimental::DataType::UNDEFINED);      \
}
#if !defined(PADDLE_WITH_HIP)
......
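The trailing ##__VA_ARGS__ parameter lets each call site append extra dtypes on top of the common base list. A hedged sketch of how such a variadic wrapper is typically invoked (the actual call sites are elided above; the bfloat16 argument here is an illustrative assumption, not confirmed by this hunk):

    // Illustrative invocation only: append bfloat16 where the toolchain
    // supports it; on HIP, register just the base dtypes.
    #if !defined(PADDLE_WITH_HIP)
    PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
    #else
    PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
    #endif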
@@ -21,14 +21,14 @@
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
-PT_REGISTER_CTX_KERNEL(conj,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::ConjKernel,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>,
-                       float,
-                       double,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(conj,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::ConjKernel,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
-PT_REGISTER_CTX_KERNEL(dot_grad,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::DotGradKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(dot_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::DotGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
-PT_REGISTER_CTX_KERNEL(dot,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::DotKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
+PT_REGISTER_KERNEL(dot,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::DotKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
@@ -18,28 +18,28 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
-PT_REGISTER_CTX_KERNEL(full,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::FullKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(full,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::FullKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(full_like,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::FullLikeKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16) {}
+PT_REGISTER_KERNEL(full_like,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::FullLikeKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16) {}
@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
-PT_REGISTER_CTX_KERNEL(
+PT_REGISTER_KERNEL(
mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
-PT_REGISTER_CTX_KERNEL(add,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::AddKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       float16,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(subtract,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::SubtractKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       float16,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(divide,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::DivideKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       float16,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(multiply,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::MultiplyKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       bool,
-                       float16,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(sum,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::SumKernel,
-                       bool,
-                       float,
-                       double,
-                       float16,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {
+PT_REGISTER_KERNEL(add,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::AddKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(subtract,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SubtractKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(divide,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::DivideKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(multiply,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MultiplyKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   float16,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(sum,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SumKernel,
+                   bool,
+                   float,
+                   double,
+                   float16,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
@@ -19,32 +19,32 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
-PT_REGISTER_CTX_KERNEL(matmul_grad,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::MatmulGradKernel,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(matmul_double_grad,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::MatmulDoubleGradKernel,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
-PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::MatmulTripleGradKernel,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MatmulGradKernel,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MatmulDoubleGradKernel,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_triple_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MatmulTripleGradKernel,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
@@ -20,12 +20,12 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
-PT_REGISTER_CTX_KERNEL(matmul,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::MatmulKernel,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::MatmulKernel,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
} // namespace pten
-PT_REGISTER_CTX_KERNEL(scale,
-                       GPU,
-                       ALL_LAYOUT,
-                       pten::ScaleKernel,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       uint8_t,
-                       int8_t,
-                       int16_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(scale,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::ScaleKernel,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}
@@ -23,5 +23,5 @@ limitations under the License. */
using float16 = paddle::platform::float16;
-PT_REGISTER_CTX_KERNEL(
+PT_REGISTER_KERNEL(
sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {}