未验证 提交 3996f0de 编写于 作者: C csy0225 提交者: GitHub

[XPU] interpolate support fp16 (#52358)

上级 d83d89ed
......@@ -75,7 +75,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -496,7 +497,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64})},
{"multi_encoder_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"nearest_interp_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"not_equal",
XPUKernelSet({phi::DataType::INT64,
......
......@@ -38,6 +38,7 @@ void InterpolateKernel(
bool align_corners,
int align_mode,
DenseTensor* output) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
phi::funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
......@@ -140,18 +141,19 @@ void InterpolateKernel(
errors::InvalidArgument("XPU nearest is only support NCHW"));
}
int r = xpu::interpolate2d<T>(ctx.x_context(),
x.data<T>(),
output->data<T>(),
n,
c,
in_h,
in_w,
out_h,
out_w,
nearest,
trans_mode,
(data_layout == DataLayout::kNCHW));
int r =
xpu::interpolate2d<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(output->data<T>()),
n,
c,
in_h,
in_w,
out_h,
out_w,
nearest,
trans_mode,
(data_layout == DataLayout::kNCHW));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "interpolate2d");
}
......@@ -221,14 +223,22 @@ void NearestInterpKernel(
} // namespace phi
PD_REGISTER_KERNEL(
bilinear_interp, XPU, ALL_LAYOUT, phi::BilinearInterpKernel, float) {
PD_REGISTER_KERNEL(bilinear_interp,
XPU,
ALL_LAYOUT,
phi::BilinearInterpKernel,
phi::dtype::float16,
float) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(
nearest_interp, XPU, ALL_LAYOUT, phi::NearestInterpKernel, float) {
PD_REGISTER_KERNEL(nearest_interp,
XPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
phi::dtype::float16,
float) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册