未验证 提交 fbe2c311 编写于 作者: L Lijunhui 提交者: GitHub

[KP] Add registry for elementwise_add/max/min/sub/div/mul/floordiv on XPU2 with KP lib (#41494)

* regist elementwise_xxx
上级 4733fe60
...@@ -69,7 +69,11 @@ PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT); ...@@ -69,7 +69,11 @@ PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_XPU_KP
PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT);
#else
PD_DECLARE_KERNEL(add_raw, KPS, ALL_LAYOUT);
#endif
PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sigmoid, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sigmoid, GPU, ALL_LAYOUT);
......
...@@ -30,6 +30,18 @@ XPUOpMap& get_kp_ops() { ...@@ -30,6 +30,18 @@ XPUOpMap& get_kp_ops() {
static XPUOpMap s_xpu_kp_kernels{ static XPUOpMap s_xpu_kp_kernels{
{"elementwise_add", {"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
// activation op // activation op
{"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
...@@ -542,7 +542,9 @@ struct InverseModuloFunctor< ...@@ -542,7 +542,9 @@ struct InverseModuloFunctor<
template <typename T> template <typename T>
struct FloorDivideFunctor { struct FloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { inline HOSTDEVICE T operator()(const T a, const T b) const {
#ifndef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
#endif
return static_cast<T>(std::trunc(a / b)); return static_cast<T>(std::trunc(a / b));
} }
}; };
...@@ -550,7 +552,9 @@ struct FloorDivideFunctor { ...@@ -550,7 +552,9 @@ struct FloorDivideFunctor {
template <typename T> template <typename T>
struct InverseFloorDivideFunctor { struct InverseFloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { inline HOSTDEVICE T operator()(const T a, const T b) const {
#ifndef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO); PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO);
#endif
return static_cast<T>(std::trunc(b / a)); return static_cast<T>(std::trunc(b / a));
} }
}; };
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "paddle/phi/kernels/elementwise_kernel.h" #include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#endif #endif
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/common/complex.h" #include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" #include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
...@@ -40,7 +42,6 @@ namespace phi { ...@@ -40,7 +42,6 @@ namespace phi {
/** /**
* Kernels * Kernels
*/ */
// Create the definition of Add // Create the definition of Add
DEFINE_CUDA_ELEMENTWISE_OP(Add) DEFINE_CUDA_ELEMENTWISE_OP(Add)
// Create the definition of Subtract // Create the definition of Subtract
...@@ -62,19 +63,34 @@ DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow) ...@@ -62,19 +63,34 @@ DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow)
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(add_raw, KPS, ALL_LAYOUT, phi::AddRawKernel, float) {}
PD_REGISTER_KERNEL(
subtract_raw, KPS, ALL_LAYOUT, phi::SubtractRawKernel, float) {}
PD_REGISTER_KERNEL(divide_raw, KPS, ALL_LAYOUT, phi::DivideRawKernel, float) {}
PD_REGISTER_KERNEL(
multiply_raw, KPS, ALL_LAYOUT, phi::MultiplyRawKernel, float) {}
PD_REGISTER_KERNEL(maximum_raw, KPS, ALL_LAYOUT, phi::MaximumRawKernel, float) {
}
PD_REGISTER_KERNEL(minimum_raw, KPS, ALL_LAYOUT, phi::MinimumRawKernel, float) {
}
PD_REGISTER_KERNEL(
floor_divide_raw, KPS, ALL_LAYOUT, phi::FloorDivideRawKernel, int) {}
#else
using float16 = phi::dtype::float16; using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16; using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
fmax, GPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {} fmax, KPS, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
fmin, GPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {} fmin, KPS, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(add_raw, PD_REGISTER_KERNEL(add_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::AddRawKernel, phi::AddRawKernel,
float, float,
...@@ -87,7 +103,7 @@ PD_REGISTER_KERNEL(add_raw, ...@@ -87,7 +103,7 @@ PD_REGISTER_KERNEL(add_raw,
complex64, complex64,
complex128) {} complex128) {}
PD_REGISTER_KERNEL(subtract_raw, PD_REGISTER_KERNEL(subtract_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::SubtractRawKernel, phi::SubtractRawKernel,
float, float,
...@@ -100,7 +116,7 @@ PD_REGISTER_KERNEL(subtract_raw, ...@@ -100,7 +116,7 @@ PD_REGISTER_KERNEL(subtract_raw,
complex64, complex64,
complex128) {} complex128) {}
PD_REGISTER_KERNEL(divide_raw, PD_REGISTER_KERNEL(divide_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::DivideRawKernel, phi::DivideRawKernel,
float, float,
...@@ -112,7 +128,7 @@ PD_REGISTER_KERNEL(divide_raw, ...@@ -112,7 +128,7 @@ PD_REGISTER_KERNEL(divide_raw,
complex64, complex64,
complex128) {} complex128) {}
PD_REGISTER_KERNEL(multiply_raw, PD_REGISTER_KERNEL(multiply_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultiplyRawKernel, phi::MultiplyRawKernel,
float, float,
...@@ -125,7 +141,7 @@ PD_REGISTER_KERNEL(multiply_raw, ...@@ -125,7 +141,7 @@ PD_REGISTER_KERNEL(multiply_raw,
complex128, complex128,
bfloat16) {} bfloat16) {}
PD_REGISTER_KERNEL(maximum_raw, PD_REGISTER_KERNEL(maximum_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::MaximumRawKernel, phi::MaximumRawKernel,
float, float,
...@@ -135,7 +151,7 @@ PD_REGISTER_KERNEL(maximum_raw, ...@@ -135,7 +151,7 @@ PD_REGISTER_KERNEL(maximum_raw,
float16, float16,
bfloat16) {} bfloat16) {}
PD_REGISTER_KERNEL(minimum_raw, PD_REGISTER_KERNEL(minimum_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::MinimumRawKernel, phi::MinimumRawKernel,
float, float,
...@@ -145,7 +161,7 @@ PD_REGISTER_KERNEL(minimum_raw, ...@@ -145,7 +161,7 @@ PD_REGISTER_KERNEL(minimum_raw,
float16, float16,
bfloat16) {} bfloat16) {}
PD_REGISTER_KERNEL(modulo_raw, PD_REGISTER_KERNEL(modulo_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::ModuloRawKernel, phi::ModuloRawKernel,
float, float,
...@@ -153,16 +169,17 @@ PD_REGISTER_KERNEL(modulo_raw, ...@@ -153,16 +169,17 @@ PD_REGISTER_KERNEL(modulo_raw,
int, int,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(floor_divide_raw, PD_REGISTER_KERNEL(floor_divide_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::FloorDivideRawKernel, phi::FloorDivideRawKernel,
int, int,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_raw, PD_REGISTER_KERNEL(elementwise_pow_raw,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::ElementwisePowRawKernel, phi::ElementwisePowRawKernel,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册