未验证 提交 b02f2aff 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Add conditional compile for gru opt (#17368)

* improve gru unit performance.
refine code

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* Add conditional compile for gru opt

Not enable gru opt if compute ability < 700

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>

* refine code.

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 6a53fa95
......@@ -30,25 +30,31 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
dim3 threads;
dim3 grid;
if (batch_size == 1) {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<T,
tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<T,
tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size, active_node,
origin_mode);
return;
if (context.GetComputeCapability() >= 70) {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size, active_node,
origin_mode);
return;
} else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grid = dim3(frame_blocks, 1);
}
} else {
threads = dim3(32, 32);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册