diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index f5224239eb2ded9a156aadc9185eca89f4e3396f..3d34a3d15c1ddd944dd205def278beeeef3efdeb 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", framework::GradVarName("X"), "grid_sampler"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output", - framework::GradVarName("Grid"), "grid_sampler"); auto input_dims = ctx->GetInputDim("X"); auto grid_dims = ctx->GetInputDim("Grid"); if (ctx->HasOutput(framework::GradVarName("X"))) { diff --git a/paddle/fluid/operators/grid_sampler_op.cu b/paddle/fluid/operators/grid_sampler_op.cu index 4e61d0c2ea7f91e4199c3e9daa3e93ac45bc0eb8..b250407fdb38199e155787a9d1c313587a0c0379 100644 --- a/paddle/fluid/operators/grid_sampler_op.cu +++ b/paddle/fluid/operators/grid_sampler_op.cu @@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel( } } - T* gGrid_ptr_NHW = grad_grid + index * grid_sW; - gGrid_ptr_NHW[0] = gix_mult * gix; - gGrid_ptr_NHW[1] = giy_mult * giy; + if (grad_grid != nullptr) { + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } } else if (mode == Mode::nearest) { int ix_nearest = static_cast(::round(ix)); int iy_nearest = static_cast(::round(iy)); @@ -412,9 +414,11 @@ __global__ void grid_sampler_cuda_backward_kernel( in_w, grad_output[gOut_offset]); } - T* gGrid_ptr_NHW = grad_grid + index * grid_sW; - gGrid_ptr_NHW[0] = static_cast(0); - gGrid_ptr_NHW[1] = static_cast(0); + if (grad_grid != nullptr) { + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = static_cast(0); + gGrid_ptr_NHW[1] = static_cast(0); + } } } } @@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel { math::SetConstant()( ctx.template device_context(), input_grad, static_cast(0)); - auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); - grid_grad->mutable_data(ctx.GetPlace()); - math::SetConstant()( - ctx.template device_context(), - grid_grad, static_cast(0)); + + T* grid_grad_data = nullptr; + if (ctx.HasOutput(framework::GradVarName("Grid"))) { + auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); + grid_grad_data = grid_grad->mutable_data(ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), + grid_grad, static_cast(0)); + } int count = static_cast(n * out_h * out_w); auto cu_stream = dev_ctx.stream(); @@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel { int grid_size = (count + block - 1) / block; grid_sampler_cuda_backward_kernel<<>>( count, output_grad->data(), input->data(), grid->data(), n, c, - out_h, out_w, in_h, in_w, input_grad->data(), grid_grad->data(), - mode, padding_mode, align_corners); + out_h, out_w, in_h, in_w, input_grad->data(), grid_grad_data, mode, + padding_mode, align_corners); } }; diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h index b8faef759ae90e14d1e83b66130bfe957b51907b..b1857b49eede0db931479fdf75e3a407558df1c0 100644 --- a/paddle/fluid/operators/grid_sampler_op.h +++ b/paddle/fluid/operators/grid_sampler_op.h @@ -450,45 +450,47 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx, auto output_grad_t = EigenTensor::From(output_grad); - Tensor grid_grad_x, grid_grad_y; - grid_grad_x.mutable_data({n, out_h, out_w}, ctx.GetPlace()); - grid_grad_y.mutable_data({n, out_h, out_w}, ctx.GetPlace()); - auto grid_grad_x_t = - EigenTensor::From(grid_grad_x).setConstant(static_cast(0.0)); - auto grid_grad_y_t = - EigenTensor::From(grid_grad_y).setConstant(static_cast(0.0)); - for (int i = 0; i < n; i++) { - for (int j = 0; j < c; j++) { - for (int k = 0; k < out_h; k++) { - for (int l = 0; l < out_w; l++) { - grid_grad_x_t(i, k, l) += - ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + - (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * - output_grad_t(i, j, k, l); - grid_grad_y_t(i, k, l) += - ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + - (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * - output_grad_t(i, j, k, l); + if (grid_grad != nullptr) { + Tensor grid_grad_x, grid_grad_y; + grid_grad_x.mutable_data({n, out_h, out_w}, ctx.GetPlace()); + grid_grad_y.mutable_data({n, out_h, out_w}, ctx.GetPlace()); + auto grid_grad_x_t = + EigenTensor::From(grid_grad_x).setConstant(static_cast(0.0)); + auto grid_grad_y_t = + EigenTensor::From(grid_grad_y).setConstant(static_cast(0.0)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < c; j++) { + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + grid_grad_x_t(i, k, l) += + ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + + (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * + output_grad_t(i, j, k, l); + grid_grad_y_t(i, k, l) += + ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + + (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * + output_grad_t(i, j, k, l); + } } } } - } - - // const T x_max = static_cast(in_w - 1); - // const T y_max = static_cast(in_h - 1); - - auto grid_x_scale_t = EigenTensor::From(*grid_x_scale); - auto grid_y_scale_t = EigenTensor::From(*grid_y_scale); - grid_grad_x_t = grid_grad_x_t * grid_x_scale_t; - grid_grad_y_t = grid_grad_y_t * grid_y_scale_t; - // gather grid_grad [x, y] in 3rd Dim - T* grid_grad_data = grid_grad->data(); - T* grid_grad_x_data = grid_grad_x.data(); - T* grid_grad_y_data = grid_grad_y.data(); - for (int i = 0; i < n * out_h * out_w; i++) { - grid_grad_data[2 * i] = grid_grad_x_data[i]; - grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; + // const T x_max = static_cast(in_w - 1); + // const T y_max = static_cast(in_h - 1); + + auto grid_x_scale_t = EigenTensor::From(*grid_x_scale); + auto grid_y_scale_t = EigenTensor::From(*grid_y_scale); + grid_grad_x_t = grid_grad_x_t * grid_x_scale_t; + grid_grad_y_t = grid_grad_y_t * grid_y_scale_t; + + // gather grid_grad [x, y] in 3rd Dim + T* grid_grad_data = grid_grad->data(); + T* grid_grad_x_data = grid_grad_x.data(); + T* grid_grad_y_data = grid_grad_y.data(); + for (int i = 0; i < n * out_h * out_w; i++) { + grid_grad_data[2 * i] = grid_grad_x_data[i]; + grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; + } } } @@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel { math::SetConstant()( ctx.template device_context(), input_grad, static_cast(0)); - auto* grid_grad = ctx.Output(framework::GradVarName("Grid")); - grid_grad->mutable_data({n, out_h, out_w, 2}, ctx.GetPlace()); - math::SetConstant()( - ctx.template device_context(), grid_grad, - static_cast(0)); + + Tensor* grid_grad = nullptr; + if (ctx.HasOutput(framework::GradVarName("Grid"))) { + grid_grad = ctx.Output(framework::GradVarName("Grid")); + grid_grad->mutable_data({n, out_h, out_w, 2}, ctx.GetPlace()); + math::SetConstant()( + ctx.template device_context(), grid_grad, + static_cast(0)); + } + Tensor grid_x, grid_y; Tensor grid_x_scale, grid_y_scale; calcGridLocationsWithGrad(