未验证 提交 3996f0de 编写于 作者: C csy0225 提交者: GitHub

[XPU] interpolate support fp16 (#52358)

上级 d83d89ed
...@@ -75,7 +75,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -75,7 +75,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT64})}, phi::DataType::INT64})},
{"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, {"bilinear_interp_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
...@@ -496,7 +497,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -496,7 +497,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64})}, phi::DataType::INT64})},
{"multi_encoder_xpu", {"multi_encoder_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, {"nearest_interp_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"not_equal", {"not_equal",
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
......
...@@ -38,6 +38,7 @@ void InterpolateKernel( ...@@ -38,6 +38,7 @@ void InterpolateKernel(
bool align_corners, bool align_corners,
int align_mode, int align_mode,
DenseTensor* output) { DenseTensor* output) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w; int n, c, in_d, in_h, in_w;
phi::funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); phi::funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
...@@ -140,18 +141,19 @@ void InterpolateKernel( ...@@ -140,18 +141,19 @@ void InterpolateKernel(
errors::InvalidArgument("XPU nearest is only support NCHW")); errors::InvalidArgument("XPU nearest is only support NCHW"));
} }
int r = xpu::interpolate2d<T>(ctx.x_context(), int r =
x.data<T>(), xpu::interpolate2d<XPUType>(ctx.x_context(),
output->data<T>(), reinterpret_cast<const XPUType*>(x.data<T>()),
n, reinterpret_cast<XPUType*>(output->data<T>()),
c, n,
in_h, c,
in_w, in_h,
out_h, in_w,
out_w, out_h,
nearest, out_w,
trans_mode, nearest,
(data_layout == DataLayout::kNCHW)); trans_mode,
(data_layout == DataLayout::kNCHW));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "interpolate2d"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "interpolate2d");
} }
...@@ -221,14 +223,22 @@ void NearestInterpKernel( ...@@ -221,14 +223,22 @@ void NearestInterpKernel(
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(bilinear_interp,
bilinear_interp, XPU, ALL_LAYOUT, phi::BilinearInterpKernel, float) { XPU,
ALL_LAYOUT,
phi::BilinearInterpKernel,
phi::dtype::float16,
float) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
} }
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(nearest_interp,
nearest_interp, XPU, ALL_LAYOUT, phi::NearestInterpKernel, float) { XPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
phi::dtype::float16,
float) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册