未验证 提交 9cc5603d 编写于 作者: W whs 提交者: GitHub

Make grid support stopping graients. (#27630)

上级 074a71bd
...@@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler"); framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid"); auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
......
...@@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel( ...@@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
} }
} }
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW; T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix; gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy; gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) { } else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix)); int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy)); int iy_nearest = static_cast<int>(::round(iy));
...@@ -412,11 +414,13 @@ __global__ void grid_sampler_cuda_backward_kernel( ...@@ -412,11 +414,13 @@ __global__ void grid_sampler_cuda_backward_kernel(
in_w, grad_output[gOut_offset]); in_w, grad_output[gOut_offset]);
} }
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW; T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0); gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0); gGrid_ptr_NHW[1] = static_cast<T>(0);
} }
} }
}
} }
template <typename T> template <typename T>
...@@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<paddle::platform::CUDADeviceContext, T>()( math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(), ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0)); input_grad, static_cast<T>(0));
T* grid_grad_data = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid")); auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>(ctx.GetPlace()); grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()( math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(), ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0)); grid_grad, static_cast<T>(0));
}
int count = static_cast<int>(n * out_h * out_w); int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream(); auto cu_stream = dev_ctx.stream();
...@@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
int grid_size = (count + block - 1) / block; int grid_size = (count + block - 1) / block;
grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>( grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c, count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad->data<T>(), out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
mode, padding_mode, align_corners); padding_mode, align_corners);
} }
}; };
......
...@@ -450,6 +450,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx, ...@@ -450,6 +450,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
auto output_grad_t = EigenTensor<T, 4>::From(output_grad); auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
if (grid_grad != nullptr) {
Tensor grid_grad_x, grid_grad_y; Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace()); grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace()); grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
...@@ -490,6 +491,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx, ...@@ -490,6 +491,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
grid_grad_data[2 * i] = grid_grad_x_data[i]; grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
} }
}
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> { ...@@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T>()( math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad, ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0)); static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
Tensor* grid_grad = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace()); grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()( math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad, ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0)); static_cast<T>(0));
}
Tensor grid_x, grid_y; Tensor grid_x, grid_y;
Tensor grid_x_scale, grid_y_scale; Tensor grid_x_scale, grid_y_scale;
calcGridLocationsWithGrad<T>( calcGridLocationsWithGrad<T>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册