Unverified commit cdadc8f0, authored by wangchaochaohu, committed by GitHub

refine temporal_shift_op for performance optimization using gpu kernel config (#28114)

Parent: b1eb28d7
@@ -11,6 +11,7 @@
 #include "paddle/fluid/operators/temporal_shift_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -112,11 +113,11 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
 
     int pixelNum = nt * chw;
-    int grid_dim = (pixelNum + 512 - 1) / 512;
-    grid_dim = grid_dim > 8 ? 8 : grid_dim;
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
-    KeTemporalShiftFw<
-        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
+    KeTemporalShiftFw<T><<<config.block_per_grid, config.thread_per_block, 0,
+                           ctx.cuda_device_context().stream()>>>(
         input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
   }
 };
@@ -148,11 +149,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
         static_cast<T>(0));
 
     int pixelNum = nt * chw;
-    int grid_dim = (pixelNum + 512 - 1) / 512;
-    grid_dim = grid_dim > 8 ? 8 : grid_dim;
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
 
-    KeTemporalShiftBw<
-        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
+    KeTemporalShiftBw<T><<<config.block_per_grid, config.thread_per_block, 0,
+                           ctx.cuda_device_context().stream()>>>(
         output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
         shift_ratio);
   }
...
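The change replaces a hard-coded launch shape (512 threads per block, grid capped at 8 blocks) with a configuration computed from the actual device via platform::GetGpuLaunchConfig1D. The sketch below is a minimal, generic illustration of how such a 1D launch config can be derived and consumed; the struct, helper, kernel, and the specific sizing policy are assumptions for illustration, not Paddle's implementation.

```cpp
#include <algorithm>

#include <cuda_runtime.h>

// Illustrative stand-in for a 1D GPU launch configuration. The field names
// mirror the config.block_per_grid / config.thread_per_block usage in the
// diff above, but this is NOT Paddle's GpuLaunchConfig.
struct LaunchConfig1D {
  dim3 thread_per_block;
  dim3 block_per_grid;
};

// Hypothetical helper: size the launch from the device's properties instead
// of the fixed <<<min(8, ceil(n / 512)), 512>>> heuristic the old code used.
LaunchConfig1D MakeLaunchConfig1D(int device_id, int element_count) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);

  // One reasonable policy (an assumption, not Paddle's exact formula): a
  // fixed block size bounded by the device limit, and enough blocks to cover
  // the data, bounded by a multiple of the SM count so the GPU stays busy
  // without launching far more blocks than it can schedule.
  int threads = std::min(512, prop.maxThreadsPerBlock);
  int blocks_needed = (element_count + threads - 1) / threads;
  int blocks = std::min(blocks_needed, prop.multiProcessorCount * 8);

  LaunchConfig1D config;
  config.thread_per_block = dim3(threads);
  config.block_per_grid = dim3(std::max(blocks, 1));
  return config;
}

// The old 8-block cap only covers every pixel if the kernel iterates with a
// grid-stride loop, which suggests KeTemporalShiftFw/Bw have this shape. A
// minimal kernel of the same pattern:
__global__ void ScaleKernel(const float* in, float* out, int n, float alpha) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = alpha * in[i];
  }
}
```

A launch then looks like the ones in the diff: `ScaleKernel<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(in, out, n, alpha);`. Because a grid-stride loop is correct for any grid size, the launch config only affects occupancy and performance, not results.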