Unverified · Commit 9cc5603d authored by whs, committed by GitHub

Make grid support stopping gradients. (#27630)

Parent 074a71bd
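The change makes the gradient with respect to `Grid` an optional output of the grad op: the hard `InferShape` check on `Grid@GRAD` is removed, the CPU and CUDA kernels allocate and zero that tensor only when `ctx.HasOutput(framework::GradVarName("Grid"))` holds, and otherwise a null pointer flows into the gradient computation. As a rough, framework-free C++ sketch of that pattern (the names below are hypothetical stand-ins, not Paddle APIs):

```cpp
// Framework-free sketch of the "optional gradient output" pattern used in
// this commit. grid_sampler_backward and its arguments are hypothetical
// stand-ins, not PaddlePaddle APIs; the real kernels operate on 4-D tensors,
// not flat vectors.
#include <cstddef>
#include <iostream>
#include <vector>

// Writes d(loss)/d(x) unconditionally and d(loss)/d(grid) only when the
// caller actually requested it (grad_grid != nullptr).
void grid_sampler_backward(const std::vector<float>& grad_out,
                           std::vector<float>* grad_x,
                           std::vector<float>* grad_grid /* may be null */) {
  for (std::size_t i = 0; i < grad_out.size(); ++i) {
    (*grad_x)[i] += grad_out[i];              // placeholder for the real math
    if (grad_grid != nullptr) {
      (*grad_grid)[i] += 0.5f * grad_out[i];  // skipped when grid stops gradients
    }
  }
}

int main() {
  std::vector<float> grad_out(4, 1.0f), grad_x(4, 0.0f), grad_grid(4, 0.0f);

  // Case 1: the grid requires gradients -> pass a real buffer.
  grid_sampler_backward(grad_out, &grad_x, &grad_grid);

  // Case 2: the grid stops gradients -> the framework never creates
  // Grid@GRAD, so the backward routine simply receives a null pointer.
  grid_sampler_backward(grad_out, &grad_x, nullptr);

  std::cout << grad_x[0] << " " << grad_grid[0] << std::endl;  // prints: 2 0.5
  return 0;
}
```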
......
@@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
......
......
@@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
}
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
......
@@ -412,11 +414,13 @@ __global__ void grid_sampler_cuda_backward_kernel(
in_w, grad_output[gOut_offset]);
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
}
}
}
template <typename T>
......
@@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0));
T* grid_grad_data = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>(ctx.GetPlace());
grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));
}
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
......
@@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
int grid_size = (count + block - 1) / block;
grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad->data<T>(),
mode, padding_mode, align_corners);
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners);
}
};
......
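The hunks above wrap every write of the grid gradient in `if (grad_grid != nullptr)` and pass `grid_grad_data` (possibly null) into the kernel launch. The stand-alone .cu sketch below shows the same device-side guard in isolation; it is an illustration only, with made-up names, sizes, and launch parameters rather than Paddle's `grid_sampler_cuda_backward_kernel`.

```cuda
// Minimal .cu sketch of the per-thread null-pointer guard for an optional
// gradient output. Not the Paddle kernel: the body and launch configuration
// are made up purely to show that a null grad_grid pointer is safe to pass.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void backward_kernel(int n, const float* grad_out, float* grad_x,
                                float* grad_grid /* may be null */) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  grad_x[i] = grad_out[i];     // the input gradient is always produced
  if (grad_grid != nullptr) {  // optional output: skip the two grid components
    grad_grid[2 * i + 0] = 0.0f;
    grad_grid[2 * i + 1] = 0.0f;
  }
}

int main() {
  const int n = 8;
  float *d_grad_out = nullptr, *d_grad_x = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d_grad_out), n * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&d_grad_x), n * sizeof(float));
  cudaMemset(d_grad_out, 0, n * sizeof(float));

  // Grid gradient not requested: launch with a null pointer for that output.
  backward_kernel<<<1, 32>>>(n, d_grad_out, d_grad_x, nullptr);
  cudaDeviceSynchronize();
  std::printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(d_grad_out);
  cudaFree(d_grad_x);
  return 0;
}
```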
......
@@ -450,6 +450,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
if (grid_grad != nullptr) {
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
......
@@ -490,6 +491,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
}
template <typename DeviceContext, typename T>
......
@@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
Tensor* grid_grad = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
}
Tensor grid_x, grid_y;
Tensor grid_x_scale, grid_y_scale;
calcGridLocationsWithGrad<T>(
......
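Taken together, the three files implement one contract: the grad op's `InferShape` no longer hard-requires `Grid@GRAD`, the host-side kernels create and zero it only when `ctx.HasOutput(framework::GradVarName("Grid"))` reports it, and every write to the grid gradient sits behind a null-pointer check. When the grid input is marked to stop gradients, the backward pass therefore skips that work instead of failing the old output check.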