未验证 提交 873b32de 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

[cherry-pick] Fix gru as small frame_size has error. (#20922) (#21440)

seems shuffle_sync cannot handle small size

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 0473cdb8
...@@ -31,23 +31,41 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -31,23 +31,41 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
dim3 grid; dim3 grid;
if (batch_size == 1) { if (batch_size == 1) {
if (context.GetComputeCapability() >= 70) { if (context.GetComputeCapability() >= 70) {
constexpr int tiled_size = 16; if (frame_size < 16) {
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; constexpr int tiled_size = 8;
threads = dim3(tiled_size, 1); int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1); threads = dim3(tiled_size, 1);
detail::KeFastCollectiveGruGate< grid = dim3(frame_blocks, 1);
T, tiled_size><<<grid, threads, 0, stream>>>( detail::KeFastCollectiveGruGate<
value.gate_value, value.prev_out_value, value.gate_weight, T, tiled_size><<<grid, threads, 0, stream>>>(
value.reset_output_value, frame_size, active_gate); value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1); frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
detail::KeFastCollectiveGruOut< grid = dim3(frame_blocks, 1);
T, tiled_size><<<grid, threads, 0, stream>>>( detail::KeFastCollectiveGruOut<
value.state_weight, value.prev_out_value, value.output_value, T, tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.reset_output_value, frame_size, active_node, value.state_weight, value.prev_out_value, value.output_value,
origin_mode); value.gate_value, value.reset_output_value, frame_size,
active_node, origin_mode);
} else {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size,
active_node, origin_mode);
}
return; return;
} else { } else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册