未验证 提交 fd9c555c 编写于 作者: W wz1qqx 提交者: GitHub

[XPU]add fp16 kernels (#54410)

上级 168fac13
...@@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip", {"clip",
XPUKernelSet({phi::DataType::FLOAT32, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::INT32})}, phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
...@@ -188,7 +189,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -188,7 +189,8 @@ XPUOpMap& get_kl2_ops() {
{"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv_v1", XPUKernelSet({phi::DataType::FLOAT32})}, {"deformable_conv_v1", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"depthwise_conv2d_transpose_grad", {"depthwise_conv2d_transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32})}, XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_transpose", {"depthwise_conv2d_transpose",
...@@ -599,7 +601,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -599,7 +601,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT8, phi::DataType::INT8,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})}, {"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad", {"relu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
...@@ -62,7 +62,8 @@ PD_REGISTER_KERNEL(swish, ...@@ -62,7 +62,8 @@ PD_REGISTER_KERNEL(swish,
#endif #endif
#if defined PADDLE_WITH_XPU #if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {} PD_REGISTER_KERNEL(
relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {} swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif #endif
......
...@@ -572,6 +572,13 @@ PD_REGISTER_KERNEL( ...@@ -572,6 +572,13 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {} log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
// Registers the `relu6_raw` kernel for the XPU backend (ALL_LAYOUT), implemented
// by phi::Relu6RawKernel, for both float and phi::dtype::float16.
// It is registered explicitly here rather than through
// PD_REGISTER_ACTIVATION_KERNEL because that helper macro lists only `float`.
PD_REGISTER_KERNEL(relu6_raw,
XPU,
ALL_LAYOUT,
phi::Relu6RawKernel,
float,
phi::dtype::float16) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
...@@ -581,7 +588,6 @@ PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) ...@@ -581,7 +588,6 @@ PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel) PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/clip_kernel.h" #include "paddle/phi/kernels/clip_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/xpu_header.h" #include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -33,8 +36,8 @@ void ClipKernel(const Context& dev_ctx, ...@@ -33,8 +36,8 @@ void ClipKernel(const Context& dev_ctx,
x_data, x_data,
out_data, out_data,
x.numel(), x.numel(),
min.to<XPUDataType>(), static_cast<XPUDataType>(min.to<T>()),
max.to<XPUDataType>()); static_cast<XPUDataType>(max.to<T>()));
PADDLE_ENFORCE_EQ(r, PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS, XPU_SUCCESS,
...@@ -46,5 +49,11 @@ void ClipKernel(const Context& dev_ctx, ...@@ -46,5 +49,11 @@ void ClipKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(clip,
clip, XPU, ALL_LAYOUT, phi::ClipKernel, float, int64_t, int) {} XPU,
ALL_LAYOUT,
phi::ClipKernel,
float,
phi::dtype::float16,
int64_t,
int) {}
...@@ -310,7 +310,11 @@ void Conv3DKernel(const Context& dev_ctx, ...@@ -310,7 +310,11 @@ void Conv3DKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {} conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(depthwise_conv2d,
depthwise_conv2d, XPU, ALL_LAYOUT, phi::DepthwiseConvKernel, float) {} XPU,
ALL_LAYOUT,
phi::DepthwiseConvKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {} conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册