未验证 提交 9cc5603d 编写于 作者: W whs 提交者: GitHub

Make grid support stopping gradients. (#27630)

上级 074a71bd
...@@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler"); framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid"); auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
......
...@@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel( ...@@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
} }
} }
T* gGrid_ptr_NHW = grad_grid + index * grid_sW; if (grad_grid != nullptr) {
gGrid_ptr_NHW[0] = gix_mult * gix; T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[1] = giy_mult * giy; gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) { } else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix)); int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy)); int iy_nearest = static_cast<int>(::round(iy));
...@@ -412,9 +414,11 @@ __global__ void grid_sampler_cuda_backward_kernel( ...@@ -412,9 +414,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
in_w, grad_output[gOut_offset]); in_w, grad_output[gOut_offset]);
} }
T* gGrid_ptr_NHW = grad_grid + index * grid_sW; if (grad_grid != nullptr) {
gGrid_ptr_NHW[0] = static_cast<T>(0); T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[1] = static_cast<T>(0); gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
} }
} }
} }
...@@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<paddle::platform::CUDADeviceContext, T>()( math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(), ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0)); input_grad, static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>(ctx.GetPlace()); T* grid_grad_data = nullptr;
math::SetConstant<paddle::platform::CUDADeviceContext, T>()( if (ctx.HasOutput(framework::GradVarName("Grid"))) {
ctx.template device_context<paddle::platform::CUDADeviceContext>(), auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad, static_cast<T>(0)); grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));
}
int count = static_cast<int>(n * out_h * out_w); int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream(); auto cu_stream = dev_ctx.stream();
...@@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
int grid_size = (count + block - 1) / block; int grid_size = (count + block - 1) / block;
grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>( grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c, count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad->data<T>(), out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
mode, padding_mode, align_corners); padding_mode, align_corners);
} }
}; };
......
...@@ -450,45 +450,47 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx, ...@@ -450,45 +450,47 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
auto output_grad_t = EigenTensor<T, 4>::From(output_grad); auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
Tensor grid_grad_x, grid_grad_y; if (grid_grad != nullptr) {
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace()); Tensor grid_grad_x, grid_grad_y;
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace()); grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto grid_grad_x_t = grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0)); auto grid_grad_x_t =
auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0)); auto grid_grad_y_t =
for (int i = 0; i < n; i++) { EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
for (int j = 0; j < c; j++) { for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) { for (int j = 0; j < c; j++) {
for (int l = 0; l < out_w; l++) { for (int k = 0; k < out_h; k++) {
grid_grad_x_t(i, k, l) += for (int l = 0; l < out_w; l++) {
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) + grid_grad_x_t(i, k, l) +=
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) * ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
output_grad_t(i, j, k, l); (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
grid_grad_y_t(i, k, l) += output_grad_t(i, j, k, l);
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) + grid_grad_y_t(i, k, l) +=
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) * ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
output_grad_t(i, j, k, l); (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
} }
} }
} }
}
// const T x_max = static_cast<T>(in_w - 1);
// const T y_max = static_cast<T>(in_h - 1);
auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;
// gather grid_grad [x, y] in 3rd Dim // const T x_max = static_cast<T>(in_w - 1);
T* grid_grad_data = grid_grad->data<T>(); // const T y_max = static_cast<T>(in_h - 1);
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>(); auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
for (int i = 0; i < n * out_h * out_w; i++) { auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_data[2 * i] = grid_grad_x_data[i]; grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_data[2 * i + 1] = grid_grad_y_data[i]; grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;
// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
} }
} }
...@@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> { ...@@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T>()( math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad, ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0)); static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace()); Tensor* grid_grad = nullptr;
math::SetConstant<DeviceContext, T>()( if (ctx.HasOutput(framework::GradVarName("Grid"))) {
ctx.template device_context<DeviceContext>(), grid_grad, grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
static_cast<T>(0)); grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
}
Tensor grid_x, grid_y; Tensor grid_x, grid_y;
Tensor grid_x_scale, grid_y_scale; Tensor grid_x_scale, grid_y_scale;
calcGridLocationsWithGrad<T>( calcGridLocationsWithGrad<T>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册