未验证 提交 2c66775b 编写于 作者: L Lijunhui 提交者: GitHub

Grid_sampler optimization (#39751)

* init grid_sampler with mode=bilinear

* solve error

* rm fill constant

* rm head

* change block size

* change block size

* optimize

* apply existing config
上级 f335d9e1
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h" #include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle { namespace paddle {
...@@ -292,15 +293,12 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> { ...@@ -292,15 +293,12 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
auto* output_data = output->mutable_data<T>(ctx.GetPlace()); auto* output_data = output->mutable_data<T>(ctx.GetPlace());
VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1] VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1]
<< "; " << output->dims()[2] << "; " << output->dims()[3]; << "; " << output->dims()[2] << "; " << output->dims()[3];
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
dev_ctx, output, static_cast<T>(0));
int count = static_cast<int>(n * out_h * out_w); int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream(); auto cu_stream = dev_ctx.stream();
int block_size = 512; platform::GpuLaunchConfig config =
int grid_size = (count + block_size - 1) / block_size; platform::GetGpuLaunchConfig1D(dev_ctx, count);
VLOG(3) << "cuda launch - grid dims: " << grid_size << "; block dims" grid_sample_cuda_kernel<
<< block_size; T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
grid_sample_cuda_kernel<T><<<grid_size, block_size, 0, cu_stream>>>(
count, n, c, out_h, out_w, in_h, in_w, input->data<T>(), count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
grid->data<T>(), output_data, mode, padding_mode, align_corners); grid->data<T>(), output_data, mode, padding_mode, align_corners);
} }
...@@ -467,19 +465,14 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -467,19 +465,14 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
if (ctx.HasOutput(framework::GradVarName("Grid"))) { if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid")); auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace()); grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));
} }
int count = static_cast<int>(n * out_h * out_w); int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream(); auto cu_stream = dev_ctx.stream();
int block_size = 512; platform::GpuLaunchConfig config =
int grid_size = (count + block_size - 1) / block_size; platform::GetGpuLaunchConfig1D(dev_ctx, count);
VLOG(3) << "cuda launch grad kernel - grid dims: " << grid_size
<< "; block dims" << block_size << "; count: " << count;
grid_sampler_cuda_backward_kernel< grid_sampler_cuda_backward_kernel<
T><<<grid_size, block_size, 0, cu_stream>>>( T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c, count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode, out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners); padding_mode, align_corners);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册