未验证 提交 0059404e 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Fix ce ocr_recognition test fails (#20987)

ocr_recognition fails, so add a path to handle small frame_size.

test=develop
上级 f56967c4
...@@ -105,7 +105,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -105,7 +105,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
* threads(tile_size, 1) * threads(tile_size, 1)
* grid(frame_blocks, 1) * grid(frame_blocks, 1)
*/ */
template <class T> template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
T *gate_weight, T *reset_output, T *gate_weight, T *reset_output,
int frame_size, int frame_size,
...@@ -113,9 +113,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, ...@@ -113,9 +113,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
T xt_0 = 0.0f; T xt_0 = 0.0f;
T a0 = 0.0f; T a0 = 0.0f;
T c0 = 0.0f; T c0 = 0.0f;
T b0[Tiled_size];
int Tiled_size = blockDim.x;
T b0[16];
int COL = blockIdx.x * blockDim.x + threadIdx.x; int COL = blockIdx.x * blockDim.x + threadIdx.x;
int Tiled_mask = ((1 << Tiled_size) - 1); int Tiled_mask = ((1 << Tiled_size) - 1);
...@@ -165,7 +163,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, ...@@ -165,7 +163,7 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
* threads(tile_size, 1) * threads(tile_size, 1)
* grid(frame_blocks, 1) * grid(frame_blocks, 1)
*/ */
template <class T> template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
T *output_value, T *gate_value, T *output_value, T *gate_value,
T *reset_value, int frame_size, T *reset_value, int frame_size,
...@@ -174,10 +172,9 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, ...@@ -174,10 +172,9 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
int COL = blockIdx.x * blockDim.x + threadIdx.x; int COL = blockIdx.x * blockDim.x + threadIdx.x;
T a0 = 0.0f; T a0 = 0.0f;
T b0[16]; T b0[Tiled_size];
T c0 = 0.0f; T c0 = 0.0f;
int Tiled_size = blockDim.x;
int Tiled_mask = ((1 << Tiled_size) - 1); int Tiled_mask = ((1 << Tiled_size) - 1);
//- Tiled matrix multiply with register shift //- Tiled matrix multiply with register shift
if (prev_out_value) { if (prev_out_value) {
......
...@@ -31,29 +31,41 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> { ...@@ -31,29 +31,41 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
dim3 grid; dim3 grid;
if (batch_size == 1) { if (batch_size == 1) {
if (context.GetComputeCapability() >= 70) { if (context.GetComputeCapability() >= 70) {
auto ComputeTiledSize = [](int frame_size) { if (frame_size < 16) {
if (frame_size >= 16) constexpr int tiled_size = 8;
return 16;
else if (frame_size < 16)
return 8;
};
auto tiled_size = ComputeTiledSize(frame_size);
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size; int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1); threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1); grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
detail::KeFastCollectiveGruGate<T><<<grid, threads, 0, stream>>>( T, tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight, value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate); value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size; frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1); grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<T><<<grid, threads, 0, stream>>>( detail::KeFastCollectiveGruOut<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value, value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size, active_node, value.gate_value, value.reset_output_value, frame_size,
origin_mode); active_node, origin_mode);
} else {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size,
active_node, origin_mode);
}
return; return;
} else { } else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024; int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册