未验证 提交 cdadc8f0 编写于 作者: W wangchaochaohu 提交者: GitHub

refine temporal_shift_op for performance optimization using gpu kernel config (#28114)

上级 b1eb28d7
......@@ -11,6 +11,7 @@
#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
......@@ -112,11 +113,11 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
int pixelNum = nt * chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
KeTemporalShiftFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeTemporalShiftFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
}
};
......@@ -148,11 +149,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
static_cast<T>(0));
int pixelNum = nt * chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
KeTemporalShiftBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
KeTemporalShiftBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
shift_ratio);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册