未验证 提交 158bf13f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Rename kernel register marco (#38861)

* rename register marco

* fix error changing

* fix format error
上级 dccdc719
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
function(kernel_declare TARGET_LIST) function(kernel_declare TARGET_LIST)
foreach(kernel_path ${TARGET_LIST}) foreach(kernel_path ${TARGET_LIST})
file(READ ${kernel_path} kernel_impl) file(READ ${kernel_path} kernel_impl)
# TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL # TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name # NOTE(chenweihang): now we don't recommend to use digit in kernel name
string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
if (NOT first_registry STREQUAL "") if (NOT first_registry STREQUAL "")
# parse the first kernel name # parse the first kernel name
string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}") string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
string(REPLACE "," "" kernel_name "${kernel_name}") string(REPLACE "," "" kernel_name "${kernel_name}")
string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
......
此差异已折叠。
...@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx, ...@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(cast, PT_REGISTER_KERNEL(cast,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::CastKernel, pten::CastKernel,
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(conj, PT_REGISTER_KERNEL(conj,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ConjKernel, pten::ConjKernel,
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad, PT_REGISTER_KERNEL(dot_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotGradKernel, pten::DotGradKernel,
......
...@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx, ...@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx,
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot, PT_REGISTER_KERNEL(dot,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotKernel, pten::DotKernel,
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h" #include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_KERNEL(full,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullKernel, pten::FullKernel,
...@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full, ...@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullLikeKernel, pten::FullLikeKernel,
......
...@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>; ...@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16; // using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_CTX_KERNEL( PT_REGISTER_KERNEL(
mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
PT_REGISTER_CTX_KERNEL(add, PT_REGISTER_KERNEL(add,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::AddKernel, pten::AddKernel,
...@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add, ...@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SubtractKernel, pten::SubtractKernel,
...@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract, ...@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(divide, PT_REGISTER_KERNEL(divide,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DivideKernel, pten::DivideKernel,
...@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide, ...@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide,
int64_t, int64_t,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MultiplyKernel, pten::MultiplyKernel,
...@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply, ...@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
bool, bool,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(sum, PT_REGISTER_KERNEL(sum,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SumKernel, pten::SumKernel,
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad, PT_REGISTER_KERNEL(matmul_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulGradKernel, pten::MatmulGradKernel,
...@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad, ...@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad, PT_REGISTER_KERNEL(matmul_double_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulDoubleGradKernel, pten::MatmulDoubleGradKernel,
...@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad, ...@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad, PT_REGISTER_KERNEL(matmul_triple_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulTripleGradKernel, pten::MatmulTripleGradKernel,
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul, PT_REGISTER_KERNEL(matmul,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulKernel, pten::MatmulKernel,
......
...@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx, ...@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(scale, PT_REGISTER_KERNEL(scale,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ScaleKernel, pten::ScaleKernel,
......
...@@ -21,5 +21,4 @@ limitations under the License. */ ...@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) { PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {}
}
...@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { ...@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(empty, PT_REGISTER_KERNEL(empty,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyKernel, pten::EmptyKernel,
...@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty, ...@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like, PT_REGISTER_KERNEL(empty_like,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyLikeKernel, pten::EmptyLikeKernel,
...@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like, ...@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(empty, PT_REGISTER_KERNEL(empty,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyKernel, pten::EmptyKernel,
...@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty, ...@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like, PT_REGISTER_KERNEL(empty_like,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::EmptyLikeKernel, pten::EmptyLikeKernel,
......
...@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx, ...@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(flatten_grad, PT_REGISTER_KERNEL(flatten_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenGradKernel, pten::FlattenGradKernel,
...@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad, ...@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
int64_t) {} int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(flatten_grad, PT_REGISTER_KERNEL(flatten_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenGradKernel, pten::FlattenGradKernel,
...@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad, ...@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_REGISTER_CTX_KERNEL(flatten_grad, PT_REGISTER_KERNEL(flatten_grad,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenGradKernel, pten::FlattenGradKernel,
......
...@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx, ...@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenKernel, pten::FlattenKernel,
...@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten, ...@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
...@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape, ...@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
int64_t) {} int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenKernel, pten::FlattenKernel,
...@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten, ...@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
...@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape, ...@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_KERNEL(flatten,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenKernel, pten::FlattenKernel,
...@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten, ...@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int, int,
int64_t) {} int64_t) {}
PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_KERNEL(flatten_with_xshape,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FlattenWithXShape, pten::FlattenWithXShape,
......
...@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx, ...@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx,
} // namespace pten } // namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ #define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_CTX_KERNEL(cast, \ PT_REGISTER_KERNEL(cast, \
GPU, \ GPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
pten::CastKernel, \ pten::CastKernel, \
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(conj, PT_REGISTER_KERNEL(conj,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ConjKernel, pten::ConjKernel,
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad, PT_REGISTER_KERNEL(dot_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotGradKernel, pten::DotGradKernel,
......
...@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx, ...@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx,
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot, PT_REGISTER_KERNEL(dot,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DotKernel, pten::DotKernel,
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h" #include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_KERNEL(full,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullKernel, pten::FullKernel,
...@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full, ...@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(full_like, PT_REGISTER_KERNEL(full_like,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::FullLikeKernel, pten::FullLikeKernel,
......
...@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16; ...@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>; using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>; using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL( PT_REGISTER_KERNEL(
mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {} mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {}
PT_REGISTER_CTX_KERNEL(add, PT_REGISTER_KERNEL(add,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::AddKernel, pten::AddKernel,
...@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add, ...@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(subtract, PT_REGISTER_KERNEL(subtract,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SubtractKernel, pten::SubtractKernel,
...@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract, ...@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(divide, PT_REGISTER_KERNEL(divide,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::DivideKernel, pten::DivideKernel,
...@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide, ...@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(multiply, PT_REGISTER_KERNEL(multiply,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MultiplyKernel, pten::MultiplyKernel,
...@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply, ...@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
float16, float16,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_CTX_KERNEL(sum, PT_REGISTER_KERNEL(sum,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::SumKernel, pten::SumKernel,
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad, PT_REGISTER_KERNEL(matmul_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulGradKernel, pten::MatmulGradKernel,
...@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad, ...@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad, PT_REGISTER_KERNEL(matmul_double_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulDoubleGradKernel, pten::MatmulDoubleGradKernel,
...@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad, ...@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle::platform::complex<float>, paddle::platform::complex<float>,
paddle::platform::complex<double>) {} paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad, PT_REGISTER_KERNEL(matmul_triple_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulTripleGradKernel, pten::MatmulTripleGradKernel,
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" #include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul, PT_REGISTER_KERNEL(matmul,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::MatmulKernel, pten::MatmulKernel,
......
...@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx, ...@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx,
} // namespace pten } // namespace pten
PT_REGISTER_CTX_KERNEL(scale, PT_REGISTER_KERNEL(scale,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::ScaleKernel, pten::ScaleKernel,
......
...@@ -23,5 +23,5 @@ limitations under the License. */ ...@@ -23,5 +23,5 @@ limitations under the License. */
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
PT_REGISTER_CTX_KERNEL( PT_REGISTER_KERNEL(
sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {} sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册