You need to sign in or sign up before continuing.
未验证 提交 663ebd5f 编写于 作者: D duanyanhui 提交者: GitHub

enhance grid_sampler cpu kernel to 5D input (#45578)

* enhance grid_sampler cpu kernel to 5D input

* fix bug when 5D input tensor running on the cudnn kernel
上级 6f2bac7c
...@@ -82,6 +82,67 @@ static inline void ClipWithMask(const CPUContext& ctx, ...@@ -82,6 +82,67 @@ static inline void ClipWithMask(const CPUContext& ctx,
} }
} }
template <typename T>
static inline void ClipWithMask3D(const CPUContext& ctx,
const int max_val, // height-1 or width-1
bool align_corners,
std::string padding_mode,
DenseTensor* grid_slice,
DenseTensor* grid_scale) {
auto& place = *ctx.eigen_device();
grid_scale->Resize(grid_slice->dims());
ctx.Alloc<T>(grid_scale);
auto grid_slice_t = EigenTensor<T, 4>::From(*grid_slice);
auto factor = static_cast<T>(max_val * 0.5);
if (!align_corners) {
factor = static_cast<T>((max_val + 1) * 0.5);
}
auto grid_scale_t = EigenTensor<T, 4>::From(*grid_scale).setConstant(factor);
if (padding_mode == "border") {
// auto bounded_lo = grid_slice_t.cwiseMax(static_cast<T>(0));
auto res = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res;
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0));
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
grid_scale_t.device(place) =
grid_scale_t * ((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>());
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
if (max_val == 0) {
grid_slice_t.device(place) = grid_slice_t.constant(static_cast<T>(0));
}
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto is_neg = ((grid_slice_t + static_cast<T>(0.5)) < static_cast<T>(0));
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
auto one_more_flip = (extra > (double_range - extra));
auto reflected =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
auto clipped = reflected.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
auto in_bound = (clipped == reflected).template cast<T>();
grid_scale_t.device(place) =
grid_scale_t *
((is_neg == one_more_flip).template cast<T>() -
(is_neg != one_more_flip).template cast<T>()) *
in_bound;
grid_slice_t.device(place) = clipped;
}
}
}
template <typename T> template <typename T>
static void CalcGridLocationsWithGrad(const CPUContext& ctx, static void CalcGridLocationsWithGrad(const CPUContext& ctx,
const DenseTensor& grid, const DenseTensor& grid,
...@@ -118,6 +179,52 @@ static void CalcGridLocationsWithGrad(const CPUContext& ctx, ...@@ -118,6 +179,52 @@ static void CalcGridLocationsWithGrad(const CPUContext& ctx,
ctx, in_h - 1, align_corners, padding_mode, grid_y, grid_y_scale); ctx, in_h - 1, align_corners, padding_mode, grid_y, grid_y_scale);
} }
template <typename T>
static void Calc3DGridLocationsWithGrad(const CPUContext& ctx,
const DenseTensor& grid,
const int in_d,
const int in_h,
const int in_w,
bool align_corners,
std::string padding_mode,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_z,
DenseTensor* grid_x_scale,
DenseTensor* grid_y_scale,
DenseTensor* grid_z_scale) {
const int n = grid.dims()[0];
const int out_d = grid.dims()[1];
const int out_h = grid.dims()[2];
const int out_w = grid.dims()[3];
// split grid with shape (n, d, h, w, 3) into (x, y, z) by the 3rd Dim
grid_x->Resize({n, out_d, out_h, out_w});
grid_y->Resize({n, out_d, out_h, out_w});
grid_z->Resize({n, out_d, out_h, out_w});
T* grid_x_data = ctx.Alloc<T>(grid_x);
T* grid_y_data = ctx.Alloc<T>(grid_y);
T* grid_z_data = ctx.Alloc<T>(grid_z);
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_d * out_h * out_w; i++) {
grid_x_data[i] = grid_data[3 * i];
grid_y_data[i] = grid_data[(3 * i) + 1];
grid_z_data[i] = grid_data[(3 * i) + 2];
}
Unnormalize3D<T>(ctx, grid_x, in_w - 1, align_corners);
Unnormalize3D<T>(ctx, grid_y, in_h - 1, align_corners);
Unnormalize3D<T>(ctx, grid_z, in_d - 1, align_corners);
ClipWithMask3D<T>(
ctx, in_w - 1, align_corners, padding_mode, grid_x, grid_x_scale);
ClipWithMask3D<T>(
ctx, in_h - 1, align_corners, padding_mode, grid_y, grid_y_scale);
ClipWithMask3D<T>(
ctx, in_d - 1, align_corners, padding_mode, grid_z, grid_z_scale);
}
template <typename T> template <typename T>
static void GatherOutputGradToInputGrad(const DenseTensor& output_grad, static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
DenseTensor* input_grad, DenseTensor* input_grad,
...@@ -156,6 +263,58 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad, ...@@ -156,6 +263,58 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
} }
} }
template <typename T>
static void Gather3DOutputGradToInputGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& z,
const DenseTensor& d1,
const DenseTensor& d2,
const DenseTensor& d3) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_d = output_grad.dims()[2];
const int out_h = output_grad.dims()[3];
const int out_w = output_grad.dims()[4];
const int in_d = input_grad->dims()[2];
const int in_h = input_grad->dims()[3];
const int in_w = input_grad->dims()[4];
auto x_t = EigenTensor<T, 4>::From(x);
auto y_t = EigenTensor<T, 4>::From(y);
auto z_t = EigenTensor<T, 4>::From(z);
auto d1_t = EigenTensor<T, 4>::From(d1);
auto d2_t = EigenTensor<T, 4>::From(d2);
auto d3_t = EigenTensor<T, 4>::From(d3);
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D(x_t(i, m, k, l),
y_t(i, m, k, l),
z_t(i, m, k, l),
(T)(in_w - 1),
(T)(in_h - 1),
(T)(in_d - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(z_t(i, m, k, l))),
static_cast<int>(round(y_t(i, m, k, l))),
static_cast<int>(round(x_t(i, m, k, l)))) +=
output_grad_t(i, j, m, k, l) * d1_t(i, m, k, l) *
d2_t(i, m, k, l) * d3_t(i, m, k, l);
}
}
}
}
}
}
}
template <typename T> template <typename T>
static void GatherBilinearGrad(const CPUContext& ctx, static void GatherBilinearGrad(const CPUContext& ctx,
const DenseTensor& input, const DenseTensor& input,
...@@ -256,6 +415,163 @@ static void GatherBilinearGrad(const CPUContext& ctx, ...@@ -256,6 +415,163 @@ static void GatherBilinearGrad(const CPUContext& ctx,
} }
} }
template <typename T>
static void Gather3DBilinearGrad(const CPUContext& ctx,
const DenseTensor& input,
const DenseTensor& output_grad,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_z,
DenseTensor* grid_x_scale,
DenseTensor* grid_y_scale,
DenseTensor* grid_z_scale,
DenseTensor* input_grad,
DenseTensor* grid_grad) {
const int n = grid_x->dims()[0];
const int out_d = grid_x->dims()[1];
const int out_h = grid_x->dims()[2];
const int out_w = grid_x->dims()[3];
const int c = input.dims()[1];
DenseTensor x_w, x_e, y_n, y_s, z_t, z_b;
DenseTensor d_w, d_e, d_n, d_s, d_t, d_b;
DenseTensor v_twn, v_ten, v_tws, v_tes, v_bwn, v_ben, v_bws, v_bes;
All3DNeigbors<T>(ctx,
input,
grid_x,
grid_y,
grid_z,
&x_w,
&x_e,
&y_n,
&y_s,
&z_t,
&z_b,
&d_w,
&d_e,
&d_n,
&d_s,
&d_t,
&d_b,
&v_twn,
&v_ten,
&v_tws,
&v_tes,
&v_bwn,
&v_ben,
&v_bws,
&v_bes);
// gather output grad value to input grad by corner point coords and weight
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_w, y_n, z_t, d_e, d_s, d_b);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_w, y_s, z_t, d_e, d_n, d_b);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_e, y_n, z_t, d_w, d_s, d_b);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_e, y_s, z_t, d_w, d_n, d_b);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_w, y_n, z_b, d_e, d_s, d_t);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_w, y_s, z_b, d_e, d_n, d_t);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_e, y_n, z_b, d_w, d_s, d_t);
Gather3DOutputGradToInputGrad<T>(
output_grad, input_grad, x_e, y_s, z_b, d_w, d_n, d_t);
auto v_twn_t = EigenTensor<T, 5>::From(v_twn);
auto v_ten_t = EigenTensor<T, 5>::From(v_ten);
auto v_tws_t = EigenTensor<T, 5>::From(v_tws);
auto v_tes_t = EigenTensor<T, 5>::From(v_tes);
auto v_bwn_t = EigenTensor<T, 5>::From(v_bwn);
auto v_ben_t = EigenTensor<T, 5>::From(v_ben);
auto v_bws_t = EigenTensor<T, 5>::From(v_bws);
auto v_bes_t = EigenTensor<T, 5>::From(v_bes);
auto d_w_t = EigenTensor<T, 4>::From(d_w);
auto d_e_t = EigenTensor<T, 4>::From(d_e);
auto d_n_t = EigenTensor<T, 4>::From(d_n);
auto d_s_t = EigenTensor<T, 4>::From(d_s);
auto d_t_t = EigenTensor<T, 4>::From(d_t);
auto d_b_t = EigenTensor<T, 4>::From(d_b);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
if (grid_grad != nullptr) {
DenseTensor grid_grad_x, grid_grad_y, grid_grad_z;
grid_grad_x.Resize({n, out_d, out_h, out_w});
grid_grad_y.Resize({n, out_d, out_h, out_w});
grid_grad_z.Resize({n, out_d, out_h, out_w});
ctx.Alloc<T>(&grid_grad_x);
ctx.Alloc<T>(&grid_grad_y);
ctx.Alloc<T>(&grid_grad_z);
auto grid_grad_x_t =
EigenTensor<T, 4>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
auto grid_grad_y_t =
EigenTensor<T, 4>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
auto grid_grad_z_t =
EigenTensor<T, 4>::From(grid_grad_z).setConstant(static_cast<T>(0.0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
grid_grad_x_t(i, m, k, l) +=
((v_ten_t(i, j, m, k, l) - v_twn_t(i, j, m, k, l)) *
d_s_t(i, m, k, l) * d_b_t(i, m, k, l) +
(v_tes_t(i, j, m, k, l) - v_tws_t(i, j, m, k, l)) *
d_n_t(i, m, k, l) * d_b_t(i, m, k, l) +
(v_ben_t(i, j, m, k, l) - v_bwn_t(i, j, m, k, l)) *
d_s_t(i, m, k, l) * d_t_t(i, m, k, l) +
(v_bes_t(i, j, m, k, l) - v_bws_t(i, j, m, k, l)) *
d_n_t(i, m, k, l) * d_t_t(i, m, k, l)) *
output_grad_t(i, j, m, k, l);
grid_grad_y_t(i, m, k, l) +=
((v_tws_t(i, j, m, k, l) - v_twn_t(i, j, m, k, l)) *
d_e_t(i, m, k, l) * d_b_t(i, m, k, l) +
(v_tes_t(i, j, m, k, l) - v_ten_t(i, j, m, k, l)) *
d_w_t(i, m, k, l) * d_b_t(i, m, k, l) +
(v_bws_t(i, j, m, k, l) - v_bwn_t(i, j, m, k, l)) *
d_e_t(i, m, k, l) * d_t_t(i, m, k, l) +
(v_bes_t(i, j, m, k, l) - v_ben_t(i, j, m, k, l)) *
d_w_t(i, m, k, l) * d_t_t(i, m, k, l)) *
output_grad_t(i, j, m, k, l);
grid_grad_z_t(i, m, k, l) +=
((v_bws_t(i, j, m, k, l) - v_tws_t(i, j, m, k, l)) *
d_e_t(i, m, k, l) * d_n_t(i, m, k, l) +
(v_bes_t(i, j, m, k, l) - v_tes_t(i, j, m, k, l)) *
d_w_t(i, m, k, l) * d_n_t(i, m, k, l) +
(v_bwn_t(i, j, m, k, l) - v_twn_t(i, j, m, k, l)) *
d_e_t(i, m, k, l) * d_s_t(i, m, k, l) +
(v_ben_t(i, j, m, k, l) - v_ten_t(i, j, m, k, l)) *
d_w_t(i, m, k, l) * d_s_t(i, m, k, l)) *
output_grad_t(i, j, m, k, l);
}
}
}
}
}
auto grid_x_scale_t = EigenTensor<T, 4>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 4>::From(*grid_y_scale);
auto grid_z_scale_t = EigenTensor<T, 4>::From(*grid_z_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;
grid_grad_z_t = grid_grad_z_t * grid_z_scale_t;
// gather grid_grad [x, y, z] in 4rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
T* grid_grad_z_data = grid_grad_z.data<T>();
for (int i = 0; i < n * out_d * out_h * out_w; i++) {
grid_grad_data[3 * i] = grid_grad_x_data[i];
grid_grad_data[3 * i + 1] = grid_grad_y_data[i];
grid_grad_data[3 * i + 2] = grid_grad_z_data[i];
}
}
}
template <typename T> template <typename T>
static void GatherOutputGradToInputGrad(const DenseTensor& output_grad, static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
DenseTensor* input_grad, DenseTensor* input_grad,
...@@ -289,6 +605,50 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad, ...@@ -289,6 +605,50 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
} }
} }
template <typename T>
static void Gather3DOutputGradToInputGrad(const DenseTensor& output_grad,
DenseTensor* input_grad,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& z) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int out_d = output_grad.dims()[2];
const int out_h = output_grad.dims()[3];
const int out_w = output_grad.dims()[4];
const int in_d = input_grad->dims()[2];
const int in_h = input_grad->dims()[3];
const int in_w = input_grad->dims()[4];
auto x_t = EigenTensor<T, 4>::From(x);
auto y_t = EigenTensor<T, 4>::From(y);
auto z_t = EigenTensor<T, 4>::From(z);
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D(x_t(i, m, k, l),
y_t(i, m, k, l),
z_t(i, m, k, l),
(T)(in_w - 1),
(T)(in_h - 1),
(T)(in_d - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(z_t(i, m, k, l))),
static_cast<int>(round(y_t(i, m, k, l))),
static_cast<int>(round(x_t(i, m, k, l)))) +=
output_grad_t(i, j, m, k, l);
}
}
}
}
}
}
}
template <typename T, typename Context> template <typename T, typename Context>
void GridSampleGradKernel(const Context& dev_ctx, void GridSampleGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -299,6 +659,7 @@ void GridSampleGradKernel(const Context& dev_ctx, ...@@ -299,6 +659,7 @@ void GridSampleGradKernel(const Context& dev_ctx,
bool align_corners, bool align_corners,
DenseTensor* x_grad, DenseTensor* x_grad,
DenseTensor* grid_grad) { DenseTensor* grid_grad) {
if (x.dims().size() == 4) {
const int n = grid.dims()[0]; const int n = grid.dims()[0];
const int out_h = grid.dims()[1]; const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2]; const int out_w = grid.dims()[2];
...@@ -346,6 +707,59 @@ void GridSampleGradKernel(const Context& dev_ctx, ...@@ -346,6 +707,59 @@ void GridSampleGradKernel(const Context& dev_ctx,
grid_y_t = grid_y_t.round(); grid_y_t = grid_y_t.round();
GatherOutputGradToInputGrad<T>(out_grid, x_grad, grid_x, grid_y); GatherOutputGradToInputGrad<T>(out_grid, x_grad, grid_x, grid_y);
} }
} else {
const int n = grid.dims()[0];
const int out_d = grid.dims()[1];
const int out_h = grid.dims()[2];
const int out_w = grid.dims()[3];
const int c = x.dims()[1];
const int in_d = x.dims()[2];
const int in_h = x.dims()[3];
const int in_w = x.dims()[4];
x_grad->Resize({n, c, in_d, in_h, in_w});
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T>()(dev_ctx, x_grad, static_cast<T>(0));
if (grid_grad != nullptr) {
grid_grad->Resize({n, out_d, out_h, out_w, 3});
dev_ctx.template Alloc<T>(grid_grad);
phi::funcs::SetConstant<Context, T>()(
dev_ctx, grid_grad, static_cast<T>(0));
}
DenseTensor grid_x, grid_y, grid_z;
DenseTensor grid_x_scale, grid_y_scale, grid_z_scale;
Calc3DGridLocationsWithGrad<T>(dev_ctx,
grid,
in_d,
in_h,
in_w,
align_corners,
padding_mode,
&grid_x,
&grid_y,
&grid_z,
&grid_x_scale,
&grid_y_scale,
&grid_z_scale);
if (mode == "bilinear") {
Gather3DBilinearGrad<T>(dev_ctx,
x,
out_grid,
&grid_x,
&grid_y,
&grid_z,
&grid_x_scale,
&grid_y_scale,
&grid_z_scale,
x_grad,
grid_grad);
} else {
Gather3DOutputGradToInputGrad<T>(
out_grid, x_grad, grid_x, grid_y, grid_z);
}
}
} }
} // namespace phi } // namespace phi
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
namespace phi { namespace phi {
using Array4 = Eigen::DSizes<int64_t, 4>; using Array4 = Eigen::DSizes<int64_t, 4>;
using Array5 = Eigen::DSizes<int64_t, 5>;
template <typename T> template <typename T>
static inline void Clip(const CPUContext& ctx, static inline void Clip(const CPUContext& ctx,
...@@ -55,6 +56,38 @@ static inline void Clip(const CPUContext& ctx, ...@@ -55,6 +56,38 @@ static inline void Clip(const CPUContext& ctx,
} }
} }
template <typename T>
static inline void Clip3D(const CPUContext& ctx,
DenseTensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners,
std::string padding_mode) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 4>::From(*grid_slice);
if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) = extra.cwiseMin(double_range - extra);
if (max_val == 0) {
grid_slice_t.device(place) = grid_slice_t.constant(static_cast<T>(0));
}
} else {
auto double_range = static_cast<T>((max_val + 1) * 2);
auto grid_abs = (grid_slice_t + static_cast<T>(0.5)).abs();
auto extra = grid_abs - (grid_abs / double_range).floor() * double_range;
grid_slice_t.device(place) =
extra.cwiseMin(double_range - extra) - static_cast<T>(0.5);
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
}
}
}
template <typename T> template <typename T>
static void CalcGridLocations(const CPUContext& ctx, static void CalcGridLocations(const CPUContext& ctx,
const DenseTensor& grid, const DenseTensor& grid,
...@@ -86,6 +119,45 @@ static void CalcGridLocations(const CPUContext& ctx, ...@@ -86,6 +119,45 @@ static void CalcGridLocations(const CPUContext& ctx,
Clip<T>(ctx, grid_y, in_h - 1, align_corners, padding_mode); Clip<T>(ctx, grid_y, in_h - 1, align_corners, padding_mode);
} }
template <typename T>
static void Calc3DGridLocations(const CPUContext& ctx,
const DenseTensor& grid,
const int in_d,
const int in_h,
const int in_w,
bool align_corners,
std::string padding_mode,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_z) {
const int n = grid.dims()[0];
const int out_d = grid.dims()[1];
const int out_h = grid.dims()[2];
const int out_w = grid.dims()[3];
// split grid with shape (n, d, h, w, 3) into (x, y, z) by the 3rd Dim
grid_x->Resize({n, out_d, out_h, out_w});
grid_y->Resize({n, out_d, out_h, out_w});
grid_z->Resize({n, out_d, out_h, out_w});
T* grid_x_data = ctx.Alloc<T>(grid_x);
T* grid_y_data = ctx.Alloc<T>(grid_y);
T* grid_z_data = ctx.Alloc<T>(grid_z);
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * out_d * out_h * out_w; i++) {
grid_x_data[i] = grid_data[3 * i];
grid_y_data[i] = grid_data[(3 * i) + 1];
grid_z_data[i] = grid_data[(3 * i) + 2];
}
Unnormalize3D<T>(ctx, grid_x, in_w - 1, align_corners);
Unnormalize3D<T>(ctx, grid_y, in_h - 1, align_corners);
Unnormalize3D<T>(ctx, grid_z, in_d - 1, align_corners);
Clip3D<T>(ctx, grid_x, in_w - 1, align_corners, padding_mode);
Clip3D<T>(ctx, grid_y, in_h - 1, align_corners, padding_mode);
Clip3D<T>(ctx, grid_z, in_d - 1, align_corners, padding_mode);
}
template <typename T> template <typename T>
static void BilinearInter(const CPUContext& ctx, static void BilinearInter(const CPUContext& ctx,
const DenseTensor& input, const DenseTensor& input,
...@@ -144,6 +216,94 @@ static void BilinearInter(const CPUContext& ctx, ...@@ -144,6 +216,94 @@ static void BilinearInter(const CPUContext& ctx,
v_es_t * d_w_scaled_t * d_n_scaled_t; v_es_t * d_w_scaled_t * d_n_scaled_t;
} }
template <typename T>
static void Bilinear3DInter(const CPUContext& ctx,
const DenseTensor& input,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_z,
DenseTensor* out) {
auto& place = *ctx.eigen_device();
const int n = grid_x->dims()[0];
const int out_d = grid_x->dims()[1];
const int out_h = grid_x->dims()[2];
const int out_w = grid_x->dims()[3];
const int c = input.dims()[1];
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
DenseTensor x_w, x_e, y_n, y_s, z_t, z_b;
DenseTensor d_w, d_e, d_n, d_s, d_t, d_b;
DenseTensor v_twn, v_ten, v_tws, v_tes, v_bwn, v_ben, v_bws, v_bes;
All3DNeigbors<T>(ctx,
input,
grid_x,
grid_y,
grid_z,
&x_w,
&x_e,
&y_n,
&y_s,
&z_t,
&z_b,
&d_w,
&d_e,
&d_n,
&d_s,
&d_t,
&d_b,
&v_twn,
&v_ten,
&v_tws,
&v_tes,
&v_bwn,
&v_ben,
&v_bws,
&v_bes);
auto d_w_t = EigenTensor<T, 4>::From(d_w);
auto d_e_t = EigenTensor<T, 4>::From(d_e);
auto d_n_t = EigenTensor<T, 4>::From(d_n);
auto d_s_t = EigenTensor<T, 4>::From(d_s);
auto d_t_t = EigenTensor<T, 4>::From(d_t);
auto d_b_t = EigenTensor<T, 4>::From(d_b);
auto d_w_scaled_t = d_w_t.reshape(Array5(n, 1, out_d, out_h, out_w))
.broadcast(Array5(1, c, 1, 1, 1));
auto d_e_scaled_t = d_e_t.reshape(Array5(n, 1, out_d, out_h, out_w))
.broadcast(Array5(1, c, 1, 1, 1));
auto d_n_scaled_t = d_n_t.reshape(Array5(n, 1, out_d, out_h, out_w))
.broadcast(Array5(1, c, 1, 1, 1));
auto d_s_scaled_t = d_s_t.reshape(Array5(n, 1, out_d, out_h, out_w))
.broadcast(Array5(1, c, 1, 1, 1));
auto d_t_scaled_t = d_t_t.reshape(Array5(n, 1, out_d, out_h, out_w))
.broadcast(Array5(1, c, 1, 1, 1));
auto d_b_scaled_t = d_b_t.reshape(Array5(n, 1, out_d, out_h, out_w))
.broadcast(Array5(1, c, 1, 1, 1));
auto v_twn_t = EigenTensor<T, 5>::From(v_twn);
auto v_ten_t = EigenTensor<T, 5>::From(v_ten);
auto v_tws_t = EigenTensor<T, 5>::From(v_tws);
auto v_tes_t = EigenTensor<T, 5>::From(v_tes);
auto v_bwn_t = EigenTensor<T, 5>::From(v_bwn);
auto v_ben_t = EigenTensor<T, 5>::From(v_ben);
auto v_bws_t = EigenTensor<T, 5>::From(v_bws);
auto v_bes_t = EigenTensor<T, 5>::From(v_bes);
auto output_t = EigenTensor<T, 5>::From(*out);
// bilinear interpolaetion by 4 corner points
output_t.device(place) =
v_twn_t * d_e_scaled_t * d_s_scaled_t * d_b_scaled_t +
v_ten_t * d_w_scaled_t * d_s_scaled_t * d_b_scaled_t +
v_tws_t * d_e_scaled_t * d_n_scaled_t * d_b_scaled_t +
v_tes_t * d_w_scaled_t * d_n_scaled_t * d_b_scaled_t +
v_bwn_t * d_e_scaled_t * d_s_scaled_t * d_t_scaled_t +
v_ben_t * d_w_scaled_t * d_s_scaled_t * d_t_scaled_t +
v_bws_t * d_e_scaled_t * d_n_scaled_t * d_t_scaled_t +
v_bes_t * d_w_scaled_t * d_n_scaled_t * d_t_scaled_t;
}
template <typename T, typename Context> template <typename T, typename Context>
void GridSampleKernel(const Context& dev_ctx, void GridSampleKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -152,6 +312,7 @@ void GridSampleKernel(const Context& dev_ctx, ...@@ -152,6 +312,7 @@ void GridSampleKernel(const Context& dev_ctx,
const std::string& padding_mode, const std::string& padding_mode,
bool align_corners, bool align_corners,
DenseTensor* out) { DenseTensor* out) {
if (x.dims().size() == 4) {
const int n = grid.dims()[0]; const int n = grid.dims()[0];
const int out_h = grid.dims()[1]; const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2]; const int out_w = grid.dims()[2];
...@@ -164,8 +325,14 @@ void GridSampleKernel(const Context& dev_ctx, ...@@ -164,8 +325,14 @@ void GridSampleKernel(const Context& dev_ctx,
phi::funcs::SetConstant<Context, T>()(dev_ctx, out, static_cast<T>(0)); phi::funcs::SetConstant<Context, T>()(dev_ctx, out, static_cast<T>(0));
DenseTensor grid_x, grid_y; DenseTensor grid_x, grid_y;
CalcGridLocations<T>( CalcGridLocations<T>(dev_ctx,
dev_ctx, grid, in_h, in_w, align_corners, padding_mode, &grid_x, &grid_y); grid,
in_h,
in_w,
align_corners,
padding_mode,
&grid_x,
&grid_y);
if (mode == "bilinear") { if (mode == "bilinear") {
BilinearInter<T>(dev_ctx, x, &grid_x, &grid_y, out); BilinearInter<T>(dev_ctx, x, &grid_x, &grid_y, out);
...@@ -176,6 +343,37 @@ void GridSampleKernel(const Context& dev_ctx, ...@@ -176,6 +343,37 @@ void GridSampleKernel(const Context& dev_ctx,
grid_y_t = grid_y_t.round(); grid_y_t = grid_y_t.round();
GetGridPointValue<T>(x, out, grid_x, grid_y); GetGridPointValue<T>(x, out, grid_x, grid_y);
} }
} else {
const int n = grid.dims()[0];
const int out_d = grid.dims()[1];
const int out_h = grid.dims()[2];
const int out_w = grid.dims()[3];
const int c = x.dims()[1];
const int in_d = x.dims()[2];
const int in_h = x.dims()[3];
const int in_w = x.dims()[4];
out->Resize(phi::make_ddim({n, c, out_d, out_h, out_w}));
dev_ctx.template Alloc<T>(out);
phi::funcs::SetConstant<Context, T>()(dev_ctx, out, static_cast<T>(0));
DenseTensor grid_x, grid_y, grid_z;
Calc3DGridLocations<T>(dev_ctx,
grid,
in_d,
in_h,
in_w,
align_corners,
padding_mode,
&grid_x,
&grid_y,
&grid_z);
if (mode == "bilinear") {
Bilinear3DInter<T>(dev_ctx, x, &grid_x, &grid_y, &grid_z, out);
} else if (mode == "nearest") {
Get3DGridPointValue<T>(x, out, grid_x, grid_y, grid_z);
}
}
} }
} // namespace phi } // namespace phi
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
...@@ -37,6 +36,24 @@ void Unnormalize(const CPUContext& ctx, ...@@ -37,6 +36,24 @@ void Unnormalize(const CPUContext& ctx,
} }
} }
template <typename T>
void Unnormalize3D(const CPUContext& ctx,
DenseTensor* grid_slice,
const int max_val, // height-1 or width-1
bool align_corners) {
auto& place = *ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 4>::From(*grid_slice);
if (!align_corners) {
auto factor = static_cast<T>((max_val + 1) * 0.5);
grid_slice_t.device(place) =
(grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
} else {
auto factor = static_cast<T>(max_val * 0.5);
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
}
}
template <typename T> template <typename T>
inline bool IsInBound(T x, T y, T x_max, T y_max) { inline bool IsInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) { if (x < 0 || x > x_max || y < 0 || y > y_max) {
...@@ -45,6 +62,14 @@ inline bool IsInBound(T x, T y, T x_max, T y_max) { ...@@ -45,6 +62,14 @@ inline bool IsInBound(T x, T y, T x_max, T y_max) {
return true; return true;
} }
template <typename T>
inline bool IsInBound3D(T x, T y, T z, T x_max, T y_max, T z_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max || z < 0 || z > z_max) {
return false;
}
return true;
}
template <typename T> template <typename T>
void GetGridPointValue(const DenseTensor& input, void GetGridPointValue(const DenseTensor& input,
DenseTensor* output, DenseTensor* output,
...@@ -157,4 +182,167 @@ void AllNeigbors(const CPUContext& ctx, ...@@ -157,4 +182,167 @@ void AllNeigbors(const CPUContext& ctx,
GetGridPointValue<T>(input, v_es, *x_e, *y_s); GetGridPointValue<T>(input, v_es, *x_e, *y_s);
} }
template <typename T>
void Get3DGridPointValue(const DenseTensor& input,
DenseTensor* output,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& z) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_d = input.dims()[2];
const int in_h = input.dims()[3];
const int in_w = input.dims()[4];
const int out_d = x.dims()[1];
const int out_h = x.dims()[2];
const int out_w = x.dims()[3];
auto x_t = EigenTensor<T, 4>::From(x);
auto y_t = EigenTensor<T, 4>::From(y);
auto z_t = EigenTensor<T, 4>::From(z);
auto output_t =
EigenTensor<T, 5>::From(*output).setConstant(static_cast<T>(0.0));
auto input_t = EigenTensor<T, 5>::From(input);
for (int i = 0; i < n; i++) {
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D(x_t(i, m, k, l),
y_t(i, m, k, l),
z_t(i, m, k, l),
(T)(in_w - 1),
(T)(in_h - 1),
(T)(in_d - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, m, k, l) =
input_t(i,
j,
static_cast<int>(round(z_t(i, m, k, l))),
static_cast<int>(round(y_t(i, m, k, l))),
static_cast<int>(round(x_t(i, m, k, l))));
}
}
}
}
}
}
}
template <typename T>
void All3DNeigbors(const CPUContext& ctx,
const DenseTensor& input,
DenseTensor* grid_x,
DenseTensor* grid_y,
DenseTensor* grid_z,
DenseTensor* x_w,
DenseTensor* x_e,
DenseTensor* y_n,
DenseTensor* y_s,
DenseTensor* z_t,
DenseTensor* z_b, // positions
DenseTensor* d_w,
DenseTensor* d_e,
DenseTensor* d_n,
DenseTensor* d_s,
DenseTensor* d_t,
DenseTensor* d_b, // distance
DenseTensor* v_twn,
DenseTensor* v_ten,
DenseTensor* v_tws,
DenseTensor* v_tes,
DenseTensor* v_bwn,
DenseTensor* v_ben,
DenseTensor* v_bws,
DenseTensor* v_bes) { // values
auto& place = *ctx.eigen_device();
const int c = input.dims()[1];
const int n = grid_x->dims()[0];
const int out_d = grid_x->dims()[1];
const int out_h = grid_x->dims()[2];
const int out_w = grid_x->dims()[3];
// calculate coords of 6 corner points
x_w->Resize({n, out_d, out_h, out_w});
x_e->Resize({n, out_d, out_h, out_w});
y_n->Resize({n, out_d, out_h, out_w});
y_s->Resize({n, out_d, out_h, out_w});
z_t->Resize({n, out_d, out_h, out_w});
z_b->Resize({n, out_d, out_h, out_w});
ctx.Alloc<T>(x_w);
ctx.Alloc<T>(x_e);
ctx.Alloc<T>(y_n);
ctx.Alloc<T>(y_s);
ctx.Alloc<T>(z_t);
ctx.Alloc<T>(z_b);
auto x_w_t = EigenTensor<T, 4>::From(*x_w);
auto x_e_t = EigenTensor<T, 4>::From(*x_e);
auto y_n_t = EigenTensor<T, 4>::From(*y_n);
auto y_s_t = EigenTensor<T, 4>::From(*y_s);
auto z_t_t = EigenTensor<T, 4>::From(*z_t);
auto z_b_t = EigenTensor<T, 4>::From(*z_b);
auto grid_x_t = EigenTensor<T, 4>::From(*grid_x);
auto grid_y_t = EigenTensor<T, 4>::From(*grid_y);
auto grid_z_t = EigenTensor<T, 4>::From(*grid_z);
x_w_t.device(place) = grid_x_t.floor();
x_e_t.device(place) = x_w_t + static_cast<T>(1);
y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + static_cast<T>(1);
z_t_t.device(place) = grid_z_t.floor();
z_b_t.device(place) = z_t_t + static_cast<T>(1);
// calculate distances to 6 sides
d_w->Resize({n, out_d, out_h, out_w});
d_e->Resize({n, out_d, out_h, out_w});
d_n->Resize({n, out_d, out_h, out_w});
d_s->Resize({n, out_d, out_h, out_w});
d_t->Resize({n, out_d, out_h, out_w});
d_b->Resize({n, out_d, out_h, out_w});
ctx.Alloc<T>(d_w);
ctx.Alloc<T>(d_e);
ctx.Alloc<T>(d_n);
ctx.Alloc<T>(d_s);
ctx.Alloc<T>(d_t);
ctx.Alloc<T>(d_b);
auto d_w_t = EigenTensor<T, 4>::From(*d_w);
auto d_e_t = EigenTensor<T, 4>::From(*d_e);
auto d_n_t = EigenTensor<T, 4>::From(*d_n);
auto d_s_t = EigenTensor<T, 4>::From(*d_s);
auto d_t_t = EigenTensor<T, 4>::From(*d_t);
auto d_b_t = EigenTensor<T, 4>::From(*d_b);
d_w_t.device(place) = grid_x_t - x_w_t;
d_e_t.device(place) = x_e_t - grid_x_t;
d_n_t.device(place) = grid_y_t - y_n_t;
d_s_t.device(place) = y_s_t - grid_y_t;
d_t_t.device(place) = grid_z_t - z_t_t;
d_b_t.device(place) = z_b_t - grid_z_t;
// calc 8 corner points value
v_twn->Resize({n, c, out_d, out_h, out_w});
v_ten->Resize({n, c, out_d, out_h, out_w});
v_tws->Resize({n, c, out_d, out_h, out_w});
v_tes->Resize({n, c, out_d, out_h, out_w});
v_bwn->Resize({n, c, out_d, out_h, out_w});
v_ben->Resize({n, c, out_d, out_h, out_w});
v_bws->Resize({n, c, out_d, out_h, out_w});
v_bes->Resize({n, c, out_d, out_h, out_w});
ctx.Alloc<T>(v_twn);
ctx.Alloc<T>(v_ten);
ctx.Alloc<T>(v_tws);
ctx.Alloc<T>(v_tes);
ctx.Alloc<T>(v_bwn);
ctx.Alloc<T>(v_ben);
ctx.Alloc<T>(v_bws);
ctx.Alloc<T>(v_bes);
Get3DGridPointValue<T>(input, v_twn, *x_w, *y_n, *z_t);
Get3DGridPointValue<T>(input, v_ten, *x_e, *y_n, *z_t);
Get3DGridPointValue<T>(input, v_tws, *x_w, *y_s, *z_t);
Get3DGridPointValue<T>(input, v_tes, *x_e, *y_s, *z_t);
Get3DGridPointValue<T>(input, v_bwn, *x_w, *y_n, *z_b);
Get3DGridPointValue<T>(input, v_ben, *x_e, *y_n, *z_b);
Get3DGridPointValue<T>(input, v_bws, *x_w, *y_s, *z_b);
Get3DGridPointValue<T>(input, v_bes, *x_e, *y_s, *z_b);
}
} // namespace phi } // namespace phi
...@@ -20,15 +20,6 @@ from op_test import OpTest, skip_check_grad_ci ...@@ -20,15 +20,6 @@ from op_test import OpTest, skip_check_grad_ci
paddle.enable_static() paddle.enable_static()
from white_list import (
op_accuracy_white_list,
check_shape_white_list,
compile_vs_runtime_white_list,
no_check_set_white_list,
op_threshold_white_list,
no_grad_set_white_list,
)
def AffineGrid(theta, grid_shape): def AffineGrid(theta, grid_shape):
n = grid_shape[0] n = grid_shape[0]
...@@ -118,7 +109,6 @@ def getGridPointValue3D(data, x, y, z): ...@@ -118,7 +109,6 @@ def getGridPointValue3D(data, x, y, z):
out_H = x.shape[2] out_H = x.shape[2]
out_W = x.shape[3] out_W = x.shape[3]
#out = np.zeros(data_shape, dtype='float64')
out = np.zeros([N, C, out_D, out_H, out_W], dtype='float64') out = np.zeros([N, C, out_D, out_H, out_W], dtype='float64')
for i in range(N): for i in range(N):
for j in range(out_D): for j in range(out_D):
...@@ -334,51 +324,15 @@ class TestGridSamplerOp(OpTest): ...@@ -334,51 +324,15 @@ class TestGridSamplerOp(OpTest):
self.padding_mode) self.padding_mode)
} }
def get_places(self):
places = []
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def test_check_output(self): def test_check_output(self):
if len(self.grid_shape) == 4:
self.check_output(check_eager=True) self.check_output(check_eager=True)
else:
check_eager_flag = True
check_dygraph_flag = False
for place in self.get_places():
res = self.check_output_with_place(
place,
atol=1e-5,
check_dygraph=check_dygraph_flag,
check_eager=check_eager_flag)
if check_eager_flag:
assert check_dygraph_flag == False
outs, eager_dygraph_outs, fetch_list = res
elif check_dygraph_flag:
uts, dygraph_outs, fetch_list = res
else:
outs, fetch_list = res
if self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST:
self.check_compile_vs_runtime(fetch_list, outs)
def test_check_grad_normal(self): def test_check_grad_normal(self):
if len(self.grid_shape) == 4:
self.check_grad(['X', 'Grid'], self.check_grad(['X', 'Grid'],
'Output', 'Output',
max_relative_error=0.01, max_relative_error=0.01,
numeric_grad_delta=self.numeric_grad_delta, numeric_grad_delta=self.numeric_grad_delta,
check_eager=True) check_eager=True)
else:
self._check_grad_helper()
for place in self.get_places():
self.check_grad_with_place(
place, ['X'],
'Output',
numeric_grad_delta=self.numeric_grad_delta,
max_relative_error=0.01,
check_eager=True,
check_dygraph=False)
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 3, 8, 8) self.x_shape = (2, 3, 8, 8)
...@@ -493,63 +447,67 @@ class Case6(TestGridSamplerOp): ...@@ -493,63 +447,67 @@ class Case6(TestGridSamplerOp):
self.align_corners = False self.align_corners = False
self.padding_mode = "zeros" self.padding_mode = "zeros"
self.mode = "bilinear" self.mode = "bilinear"
self.numeric_grad_delta = 0.000001
class Case6_(TestGridSamplerOp): class Case6_(TestGridSamplerOp):
def get_places(self):
places = []
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7) self.x_shape = (2, 3, 4, 5, 6)
self.grid_shape = (2, 8, 9, 10, 3) self.grid_shape = (2, 7, 8, 9, 3)
self.theta_shape = (2, 3, 4) self.theta_shape = (2, 3, 4)
self.align_corners = False self.align_corners = False
self.padding_mode = "border" self.padding_mode = "border"
self.mode = "bilinear" self.mode = "bilinear"
self.numeric_grad_delta = 0.000001
class Case7(TestGridSamplerOp): class Case7(TestGridSamplerOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7) self.x_shape = (2, 3, 4, 5, 6)
self.grid_shape = (2, 8, 9, 10, 3) self.grid_shape = (2, 7, 8, 9, 3)
self.theta_shape = (2, 3, 4) self.theta_shape = (2, 3, 4)
self.align_corners = False self.align_corners = False
self.padding_mode = "reflection" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
self.numeric_grad_delta = 0.000001
class Case8(TestGridSamplerOp): class Case8(TestGridSamplerOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7) self.x_shape = (2, 3, 4, 5, 6)
self.grid_shape = (2, 8, 9, 10, 3) self.grid_shape = (2, 7, 8, 9, 3)
self.theta_shape = (2, 3, 4) self.theta_shape = (2, 3, 4)
self.align_corners = True self.align_corners = True
self.padding_mode = "reflection" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
self.numeric_grad_delta = 0.000001
class Case9(TestGridSamplerOp): class Case9(TestGridSamplerOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7) self.x_shape = (2, 3, 4, 5, 6)
self.grid_shape = (2, 8, 9, 10, 3) self.grid_shape = (2, 7, 8, 9, 3)
self.theta_shape = (2, 3, 4) self.theta_shape = (2, 3, 4)
self.align_corners = False self.align_corners = False
self.padding_mode = "reflection" self.padding_mode = "reflection"
self.mode = "nearest" self.mode = "nearest"
self.numeric_grad_delta = 0.0001 self.numeric_grad_delta = 0.000001
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " + @skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass") "however it is desirable to cover the forward pass")
class LargeInput3DCase(TestGridSamplerOp): class LargeInput3DCase(TestGridSamplerOp):
def get_places(self):
places = []
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def initTestCase(self): def initTestCase(self):
self.no_need_check_grad = True self.no_need_check_grad = True
self.x_shape = (2, 3, 24, 24, 12) self.x_shape = (2, 3, 24, 24, 12)
...@@ -558,8 +516,8 @@ class LargeInput3DCase(TestGridSamplerOp): ...@@ -558,8 +516,8 @@ class LargeInput3DCase(TestGridSamplerOp):
self.align_corners = False self.align_corners = False
self.padding_mode = "reflection" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
self.numeric_grad_delta = 0.000001
self.use_cudnn = False self.use_cudnn = False
self.__class__.op_type = 'grid_sampler'
def test_check_grad_normal(self): def test_check_grad_normal(self):
pass pass
...@@ -577,8 +535,7 @@ class Case10(LargeInput3DCase): ...@@ -577,8 +535,7 @@ class Case10(LargeInput3DCase):
self.align_corners = True self.align_corners = True
self.padding_mode = "zeros" self.padding_mode = "zeros"
self.mode = "bilinear" self.mode = "bilinear"
self.use_cudnn = False self.numeric_grad_delta = 0.000001
self.__class__.op_type = 'grid_sampler'
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -275,6 +275,9 @@ def grid_sample(x, ...@@ -275,6 +275,9 @@ def grid_sample(x,
x.stop_gradient = False x.stop_gradient = False
grid.stop_gradient = False grid.stop_gradient = False
if len(grid.shape) == 5:
use_cudnn = False
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.grid_sample(x, grid, mode, padding_mode, align_corners) return _C_ops.grid_sample(x, grid, mode, padding_mode, align_corners)
elif in_dynamic_mode(): elif in_dynamic_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册