diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h
index d3e5e381a50df78c7d7533aaaafe6f6ddfde4644..e07655eaac2bc2731a3e52cdd7405828bbf9ac3c 100644
--- a/paddle/operators/math/detail/lstm_gpu_kernel.h
+++ b/paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -227,7 +227,9 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
     grid = dim3(frameBlocks, 1);
   } else {
     /* framePerBlock = 32 batchPerBlock = 32 */
-    threads = dim3(32, 32);
+    // NOTE(review): the block is now 32x16 while the grid below still tiles
+    // batchSize by 32 -- confirm the kernel still covers every batch row.
+    threads = dim3(32, 16);
     grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
   }
 
@@ -244,6 +246,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
         op, value, grad, frameSize, batchSize, active_node, active_gate,
         active_state);
   }
+
+  // Wait for the work queued on `stream` so that a failure of the launch
+  // above is visible to the cudaGetLastError() call below.
+  cudaStreamSynchronize(stream);
+  // TODO(qingqing): Add cuda error check for each kernel.
+  cudaError_t err = cudaGetLastError();
+  PADDLE_ENFORCE_EQ(err, cudaSuccess, "%s", cudaGetErrorString(err));
 }
 
 }  // namespace detail