Unverified commit 8a2caacd, authored by zhaoyuchen2018, committed by GitHub

improve GRU unit performance. (#16338)

refine code

fuse cuBLAS calls and kernels into one CUDA kernel.

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent ddb24d48
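For context: before this change, the batch-size-1 forward path went through the same sequence as the batched one: a cuBLAS GEMM to multiply the previous hidden state by the gate weights, an elementwise kernel to activate the gates, a second GEMM for the candidate state, and another kernel for the final output. The new path fuses each small matmul with its follow-up kernel. A rough sketch of the before/after launch pattern (call shapes are illustrative, not the exact Paddle call sites):

    // Before (sketch): GEMM via cuBLAS, then a separate elementwise kernel.
    //   blas.GEMM(...);                                   // h_{t-1} x gate_weight
    //   KeGruForwardResetOutput<<<grid, threads>>>(...);  // activate gates
    //   blas.GEMM(...);                                   // (r_t * h_{t-1}) x state_weight
    //   KeGruForwardFinalOutput<<<grid, threads>>>(...);  // final output
    //
    // After (sketch): each stage does its skinny matmul with warp shuffles
    // inside the same kernel that applies the activation.
    //   KeFastCollectiveGruGate<T, 16><<<grid, threads, 0, stream>>>(...);
    //   KeFastCollectiveGruOut<T, 16><<<grid, threads, 0, stream>>>(...);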
@@ -101,6 +101,122 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
  output_value[frame_idx] = r_output;
}
/*
 * threads(tile_size, 1)
 * grid(frame_blocks, 1)
 */
template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
                                        T *gate_weight, T *reset_output,
                                        int frame_size,
                                        ActivationType active_node) {
  T xt_0 = 0.0f;
  T a0 = 0.0f;
  T c0 = 0.0f;
  T b0[Tiled_size];

  // One thread per column of the concatenated update/reset gates
  // (2 * frame_size columns in total).
  int COL = blockIdx.x * blockDim.x + threadIdx.x;
  int Tiled_mask = ((1 << Tiled_size) - 1);
  // Tiled matrix multiply using register shuffles, faster than shared memory.
  if (prev_output_value) {
    for (int k = 0; k < (((frame_size - 1) / Tiled_size) + 1); ++k) {
      // Each lane caches one element of the previous hidden state.
      a0 = 0;
      if ((threadIdx.x + k * Tiled_size) < frame_size) {
        a0 = prev_output_value[threadIdx.x + (k * Tiled_size)];
      }
      for (int i = 0; i < Tiled_size; i++) {
        if (COL < frame_size * 2 && (i + k * Tiled_size) < frame_size) {
          b0[i] = gate_weight[(i + k * Tiled_size) * frame_size * 2 + COL];
        }
      }

      // Broadcast lane i's a0 to the whole tile and accumulate.
      for (int i = 0; i < Tiled_size; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        c0 = c0 + __shfl_sync(Tiled_mask, a0, i, Tiled_size) * b0[i];
#else
        c0 = c0 + __shfl(a0, i, Tiled_size) * b0[i];
#endif
      }
    }
  }

  __syncthreads();

  if (COL < frame_size * 2) {
    // Add the input projection and apply the gate activation.
    xt_0 = gate_value[COL];
    c0 += xt_0;
    c0 = forward::activation(c0, active_node);
    gate_value[COL] = c0;
    if (frame_size <= COL && COL < frame_size * 2) {
      // Reset-gate half: also produce r_t * h_{t-1} for the next stage.
      T htp_0 = 0.0;
      if (prev_output_value) {
        htp_0 = prev_output_value[COL - frame_size];
      }
      reset_output[COL - frame_size] = c0 * htp_0;
    } else if (COL < frame_size) {
      gate_value[COL] = c0;
    }
  }
}
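The inner loops above form a tiled matrix-vector product kept entirely in registers: each of the Tiled_size threads caches one element of prev_output_value in a0, and the warp shuffle broadcasts lane i's cached element to the whole tile, so every thread accumulates its own output column without touching shared memory. A minimal standalone sketch of the same trick, with hypothetical names, assuming TILE == 16 and blockDim.x == TILE:

    // Sketch only: y = x * W for an [n] vector and a row-major [n x m]
    // matrix, using the register-shuffle tiling from the kernels above.
    template <int TILE>
    __global__ void TiledMatVec(const float *x, const float *W, float *y,
                                int n, int m) {
      int col = blockIdx.x * blockDim.x + threadIdx.x;  // output column
      float acc = 0.0f;
      for (int k = 0; k < (n + TILE - 1) / TILE; ++k) {
        // Each lane caches one element of x in a register.
        float x_reg = 0.0f;
        if (threadIdx.x + k * TILE < n) x_reg = x[threadIdx.x + k * TILE];
        for (int i = 0; i < TILE; ++i) {
          // Broadcast lane i's cached element to every lane in the tile.
          float x_i = __shfl_sync((1u << TILE) - 1u, x_reg, i, TILE);
          if (col < m && i + k * TILE < n) {
            acc += x_i * W[(i + k * TILE) * m + col];
          }
        }
      }
      if (col < m) y[col] = acc;
    }
    // Launch (sketch): TiledMatVec<16><<<(m + 15) / 16, 16>>>(x, W, y, n, m);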
/*
 * threads(tile_size, 1)
 * grid(frame_blocks, 1)
 */
template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
                                       T *output_value, T *gate_value,
                                       T *reset_value, int frame_size,
                                       ActivationType act_node,
                                       bool origin_mode) {
  int COL = blockIdx.x * blockDim.x + threadIdx.x;

  T a0 = 0.0f;
  T b0[Tiled_size];
  T c0 = 0.0f;

  int Tiled_mask = ((1 << Tiled_size) - 1);
  // Tiled matrix multiply with register shuffles.
  if (prev_out_value) {
    for (int k = 0; k < (((frame_size - 1) / Tiled_size) + 1); ++k) {
      a0 = 0;
      if ((threadIdx.x + k * Tiled_size) < frame_size) {
        a0 = reset_value[threadIdx.x + (k * Tiled_size)];
      }
      for (int i = 0; i < Tiled_size; i++) {
        if (COL < frame_size && (i + k * Tiled_size) < frame_size) {
          b0[i] = gate_weight[(i + k * Tiled_size) * frame_size + COL];
        }
      }

      for (int i = 0; i < Tiled_size; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        c0 = c0 + __shfl_sync(Tiled_mask, a0, i, Tiled_size) * b0[i];
#else
        c0 = c0 + __shfl(a0, i, Tiled_size) * b0[i];
#endif
      }
    }
  }

  __syncthreads();

  if (COL < frame_size) {
    T xt_0 = gate_value[COL + 2 * frame_size];
    T gta_0 = gate_value[COL];
    T htp_0 = 0;
    if (prev_out_value) htp_0 = prev_out_value[COL];
    c0 += xt_0;
    c0 = forward::activation(c0, act_node);
    gate_value[COL + 2 * frame_size] = c0;
    if (origin_mode) {
      output_value[COL] = htp_0 * gta_0 + (1 - gta_0) * c0;
    } else {
      output_value[COL] = c0 * gta_0 + (1 - gta_0) * htp_0;
    }
  }
}
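The final branch above is the GRU interpolation step in its two conventions: writing u for the update gate (gta_0), c for the candidate state (c0), and h_prev for the previous hidden state (htp_0), origin_mode computes h_t = u * h_prev + (1 - u) * c, while the default mode computes h_t = u * c + (1 - u) * h_prev. A scalar reference, as a sketch with our own names:

    // Sketch only: scalar form of the final interpolation above.
    template <typename T>
    inline T GruInterpolate(T u, T c, T h_prev, bool origin_mode) {
      // origin_mode: h_t = u * h_prev + (1 - u) * c
      // default    : h_t = u * c      + (1 - u) * h_prev
      return origin_mode ? u * h_prev + (static_cast<T>(1) - u) * c
                         : u * c + (static_cast<T>(1) - u) * h_prev;
    }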
/*
 * threads(frame_per_block, batch_per_block)
 * grid(frame_blocks, batch_blocks)
...
@@ -30,10 +30,25 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
     dim3 threads;
     dim3 grid;
     if (batch_size == 1) {
-      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
-      int frame_blocks = (frame_size + 1024 - 1) / 1024;
-      threads = dim3(frame_per_block, 1);
+      constexpr int tiled_size = 16;
+      int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
+      threads = dim3(tiled_size, 1);
       grid = dim3(frame_blocks, 1);
+      detail::KeFastCollectiveGruGate<T,
+                                      tiled_size><<<grid, threads, 0, stream>>>(
+          value.gate_value, value.prev_out_value, value.gate_weight,
+          value.reset_output_value, frame_size, active_gate);
+      frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
+      grid = dim3(frame_blocks, 1);
+      detail::KeFastCollectiveGruOut<T,
+                                     tiled_size><<<grid, threads, 0, stream>>>(
+          value.state_weight, value.prev_out_value, value.output_value,
+          value.gate_value, value.reset_output_value, frame_size, active_node,
+          origin_mode);
+      return;
     } else {
       threads = dim3(32, 32);
       grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
...
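For reference, the launch arithmetic of the new batch-size-1 branch: the gate kernel spans both gate halves, so its grid covers 2 * frame_size columns, while the output kernel covers frame_size columns; both use 16-thread blocks. A small sketch of the same computation (helper names are ours; dim3 is available in any .cu translation unit):

    constexpr int kTiledSize = 16;  // matches tiled_size above

    // Grid for KeFastCollectiveGruGate: one thread per column of the
    // concatenated update/reset gate outputs (2 * frame_size columns).
    inline dim3 GateGrid(int frame_size) {
      return dim3((frame_size * 2 + kTiledSize - 1) / kTiledSize, 1);
    }

    // Grid for KeFastCollectiveGruOut: one thread per output column.
    inline dim3 OutGrid(int frame_size) {
      return dim3((frame_size + kTiledSize - 1) / kTiledSize, 1);
    }

    // e.g. frame_size = 100: GateGrid(100).x == 13, OutGrid(100).x == 7.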