Unverified commit fd9c555c — authored by wz1qqx, committed via GitHub

[XPU]add fp16 kernels (#54410)

Parent commit: 168fac13
......@@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -188,7 +189,8 @@ XPUOpMap& get_kl2_ops() {
{"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv_v1", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"depthwise_conv2d_transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_transpose",
......@@ -599,7 +601,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
......@@ -62,7 +62,8 @@ PD_REGISTER_KERNEL(swish,
#endif
// XPU registrations for the relu6 and swish activation kernels.
#if defined PADDLE_WITH_XPU
// NOTE(review): this text is a rendered diff that lost its -/+ markers.
// The float-only relu6 registration below is the removed line; the
// float + float16 registration that follows is its replacement (the hunk
// header "@@ -62,7 +62,8" shows a one-line growth). A real source file
// keeps only one of them — both together would register relu6 twice.
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(
relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
// swish already covered both float and float16 — unchanged diff context.
PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif
......
......@@ -572,6 +572,13 @@ PD_REGISTER_KERNEL(
// log kernel registration (float + float16) — unchanged diff context line.
PD_REGISTER_KERNEL(
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
// relu6_raw now gets an explicit registration so it can also cover
// phi::dtype::float16; the old float-only
// PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) line is removed
// from the activation list later in this diff.
PD_REGISTER_KERNEL(relu6_raw,
XPU,
ALL_LAYOUT,
phi::Relu6RawKernel,
float,
phi::dtype::float16) {}
// Helper macro: registers an activation kernel for float only on XPU.
// (Do not insert anything between the two macro lines — the trailing
// backslash continues the definition onto the next line.)
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
......@@ -581,7 +588,6 @@ PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
......@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/phi/kernels/clip_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -33,8 +36,8 @@ void ClipKernel(const Context& dev_ctx,
x_data,
out_data,
x.numel(),
min.to<XPUDataType>(),
max.to<XPUDataType>());
static_cast<XPUDataType>(min.to<T>()),
static_cast<XPUDataType>(max.to<T>()));
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
......@@ -46,5 +49,11 @@ void ClipKernel(const Context& dev_ctx,
} // namespace phi
// NOTE(review): rendered diff without -/+ markers — the compact clip
// registration (float/int64_t/int) below is the removed version; the
// multi-line registration that follows replaces it, adding
// phi::dtype::float16 to the supported dtypes. Only one of the two would
// exist in the actual source file.
PD_REGISTER_KERNEL(
clip, XPU, ALL_LAYOUT, phi::ClipKernel, float, int64_t, int) {}
PD_REGISTER_KERNEL(clip,
XPU,
ALL_LAYOUT,
phi::ClipKernel,
float,
phi::dtype::float16,
int64_t,
int) {}
......@@ -310,7 +310,11 @@ void Conv3DKernel(const Context& dev_ctx,
// conv2d already supported float and float16 — unchanged diff context line.
PD_REGISTER_KERNEL(
conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {}
// NOTE(review): diff residue — the float-only depthwise_conv2d
// registration below is the removed line; the multi-line registration that
// follows replaces it and adds phi::dtype::float16. Only one belongs in
// the actual source file.
PD_REGISTER_KERNEL(
depthwise_conv2d, XPU, ALL_LAYOUT, phi::DepthwiseConvKernel, float) {}
PD_REGISTER_KERNEL(depthwise_conv2d,
XPU,
ALL_LAYOUT,
phi::DepthwiseConvKernel,
float,
phi::dtype::float16) {}
// conv3d registration (float + float16) — unchanged diff context line.
PD_REGISTER_KERNEL(
conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {}
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment.