未验证 提交 fbe2c311 编写于 作者: L Lijunhui 提交者: GitHub

[KP] Add registry for elementwise_add/max/min/sub/div/mul/floordiv on XPU2 with KP lib (#41494)

* regist elementwise_xxx
上级 4733fe60
......@@ -69,7 +69,11 @@ PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_XPU_KP
PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT);
#else
PD_DECLARE_KERNEL(add_raw, KPS, ALL_LAYOUT);
#endif
PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sigmoid, GPU, ALL_LAYOUT);
......
......@@ -30,6 +30,18 @@ XPUOpMap& get_kp_ops() {
static XPUOpMap s_xpu_kp_kernels{
{"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_max",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_min",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_mul",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_floordiv",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
// activation op
{"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
......@@ -542,7 +542,9 @@ struct InverseModuloFunctor<
template <typename T>
struct FloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
#ifndef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
#endif
return static_cast<T>(std::trunc(a / b));
}
};
......@@ -550,7 +552,9 @@ struct FloorDivideFunctor {
template <typename T>
struct InverseFloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
#ifndef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO);
#endif
return static_cast<T>(std::trunc(b / a));
}
};
......
......@@ -17,7 +17,7 @@
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#endif
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
......@@ -40,7 +42,6 @@ namespace phi {
/**
* Kernels
*/
// Create the definition of Add
DEFINE_CUDA_ELEMENTWISE_OP(Add)
// Create the definition of Subtract
......@@ -62,19 +63,34 @@ DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow)
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(add_raw, KPS, ALL_LAYOUT, phi::AddRawKernel, float) {}
PD_REGISTER_KERNEL(
subtract_raw, KPS, ALL_LAYOUT, phi::SubtractRawKernel, float) {}
PD_REGISTER_KERNEL(divide_raw, KPS, ALL_LAYOUT, phi::DivideRawKernel, float) {}
PD_REGISTER_KERNEL(
multiply_raw, KPS, ALL_LAYOUT, phi::MultiplyRawKernel, float) {}
PD_REGISTER_KERNEL(maximum_raw, KPS, ALL_LAYOUT, phi::MaximumRawKernel, float) {
}
PD_REGISTER_KERNEL(minimum_raw, KPS, ALL_LAYOUT, phi::MinimumRawKernel, float) {
}
PD_REGISTER_KERNEL(
floor_divide_raw, KPS, ALL_LAYOUT, phi::FloorDivideRawKernel, int) {}
#else
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
fmax, GPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
fmax, KPS, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(
fmin, GPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
fmin, KPS, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(add_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::AddRawKernel,
float,
......@@ -87,7 +103,7 @@ PD_REGISTER_KERNEL(add_raw,
complex64,
complex128) {}
PD_REGISTER_KERNEL(subtract_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::SubtractRawKernel,
float,
......@@ -100,7 +116,7 @@ PD_REGISTER_KERNEL(subtract_raw,
complex64,
complex128) {}
PD_REGISTER_KERNEL(divide_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::DivideRawKernel,
float,
......@@ -112,7 +128,7 @@ PD_REGISTER_KERNEL(divide_raw,
complex64,
complex128) {}
PD_REGISTER_KERNEL(multiply_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::MultiplyRawKernel,
float,
......@@ -125,7 +141,7 @@ PD_REGISTER_KERNEL(multiply_raw,
complex128,
bfloat16) {}
PD_REGISTER_KERNEL(maximum_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::MaximumRawKernel,
float,
......@@ -135,7 +151,7 @@ PD_REGISTER_KERNEL(maximum_raw,
float16,
bfloat16) {}
PD_REGISTER_KERNEL(minimum_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::MinimumRawKernel,
float,
......@@ -145,7 +161,7 @@ PD_REGISTER_KERNEL(minimum_raw,
float16,
bfloat16) {}
PD_REGISTER_KERNEL(modulo_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::ModuloRawKernel,
float,
......@@ -153,16 +169,17 @@ PD_REGISTER_KERNEL(modulo_raw,
int,
int64_t) {}
PD_REGISTER_KERNEL(floor_divide_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::FloorDivideRawKernel,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_raw,
GPU,
KPS,
ALL_LAYOUT,
phi::ElementwisePowRawKernel,
float,
double,
int,
int64_t) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册