未验证 提交 520f48d6 编写于 作者: Z zhangyikun02 提交者: GitHub

support grid_sampler_grad op for XPU (#49857)

上级 f71796b6
...@@ -313,6 +313,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -313,6 +313,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})}, {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/grid_sample_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void GridSampleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& grid,
const DenseTensor& out_grid,
const std::string& mode,
const std::string& padding_mode,
bool align_corners,
DenseTensor* x_grad,
DenseTensor* grid_grad) {
PADDLE_ENFORCE_EQ(
x.dims().size(),
4,
phi::errors::InvalidArgument(
("XPU is only support input_dims == 4 in grid_sample_grad op.")));
const int64_t n = grid.dims()[0];
const int64_t out_h = grid.dims()[1];
const int64_t out_w = grid.dims()[2];
const int64_t c = x.dims()[1];
const int64_t in_h = x.dims()[2];
const int64_t in_w = x.dims()[3];
x_grad->Resize({n, c, in_h, in_w});
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
T* grid_grad_ptr = nullptr;
if (grid_grad != nullptr) {
grid_grad->Resize({n, out_h, out_w, 2});
grid_grad_ptr = dev_ctx.template Alloc<T>(grid_grad);
}
bool is_nearest = false;
if (mode == "nearest") {
is_nearest = true;
}
int64_t padding_mode_type = 0;
if (padding_mode == "border") {
padding_mode_type = 1;
} else if (padding_mode == "reflection") {
padding_mode_type = 2;
}
int r = xpu::grid_sample_grad<T>(dev_ctx.x_context(),
x.data<T>(),
grid.data<T>(),
out_grid.data<T>(),
x_grad_ptr,
grid_grad_ptr,
n,
c,
in_h,
in_w,
out_h,
out_w,
is_nearest,
align_corners,
padding_mode_type,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "grid_sample_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
grid_sample_grad, XPU, ALL_LAYOUT, phi::GridSampleGradKernel, float) {}
...@@ -35,6 +35,7 @@ class XPUOpTest(OpTest): ...@@ -35,6 +35,7 @@ class XPUOpTest(OpTest):
'''Fix random seeds to remove randomness from tests''' '''Fix random seeds to remove randomness from tests'''
cls.use_xpu = True cls.use_xpu = True
cls.use_mkldnn = False cls.use_mkldnn = False
cls.epsilon_xpu2xpu = 0.00000001
super().setUpClass() super().setUpClass()
@classmethod @classmethod
...@@ -212,7 +213,11 @@ class XPUOpTest(OpTest): ...@@ -212,7 +213,11 @@ class XPUOpTest(OpTest):
user_defined_grad_outputs=user_defined_grad_outputs, user_defined_grad_outputs=user_defined_grad_outputs,
) )
self._assert_is_close( self._assert_is_close(
a1, a2, inputs_to_check, 0.00000001, "Gradient Check On two xpu" a1,
a2,
inputs_to_check,
self.epsilon_xpu2xpu,
"Gradient Check On two xpu",
) )
self._assert_is_close( self._assert_is_close(
a1, a1,
......
...@@ -170,6 +170,7 @@ class XPUTestGridSamplerOP(XPUOpTestWrapper): ...@@ -170,6 +170,7 @@ class XPUTestGridSamplerOP(XPUOpTestWrapper):
self.place = paddle.XPUPlace(0) self.place = paddle.XPUPlace(0)
self.init_dtype() self.init_dtype()
self.op_type = 'grid_sampler' self.op_type = 'grid_sampler'
self.epsilon_xpu2xpu = 0.000001
self.use_cudnn = False self.use_cudnn = False
self.align_corners = True self.align_corners = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册