From 4c7be265d339dec75d7076c8018b222384ecc436 Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Sun, 13 Jan 2019 18:10:05 +0800
Subject: [PATCH] update avx gru grad kernel test=develop

---
 .../operators/math/detail/gru_gpu_kernel.h   |  5 ++--
 .../fluid/operators/math/detail/gru_kernel.h | 30 ++++++++++++-------
 paddle/fluid/operators/math/gru_compute.cu   |  7 +++--
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
index 8d133f532..6b57da104 100644
--- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
+++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
@@ -110,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *output_grad,
                                        int frame_size, int batch_size,
-                                       ActivationType active_node) {
+                                       ActivationType active_node,
+                                       bool origin_mode) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -140,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
   op_state_grad(&r_update_gate_value, &r_update_gate_grad,
                 &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
                 &r_prev_out_grad,
-                &r_out_grad, active_node);
+                &r_out_grad, active_node, origin_mode);
 
   gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
   gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/fluid/operators/math/detail/gru_kernel.h
index fa4f94b6b..c464d9cec 100644
--- a/paddle/fluid/operators/math/detail/gru_kernel.h
+++ b/paddle/fluid/operators/math/detail/gru_kernel.h
@@ -132,16 +132,26 @@ class gru_stateGrad {
                              __m256 *grad_frame_state, __m256 *value_prev_out,
                              __m256 *grad_prev_out, __m256 *grad_output,
                              ActivationType act_input, bool origin_mode) {
-    *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
-    *grad_update_gate = _mm256_sub_ps(
-        *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
-    *grad_prev_out = _mm256_add_ps(
-        _mm256_sub_ps(*grad_prev_out,
-                      _mm256_mul_ps(*grad_output, *value_update_gate)),
-        *grad_output);
-    *grad_frame_state =
-        activation(_mm256_mul_ps(*grad_output, *value_update_gate),
-                   *value_frame_state, act_input);
+    if (origin_mode) {
+      *grad_update_gate = _mm256_mul_ps(
+          *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
+      *grad_prev_out = _mm256_add_ps(
+          *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
+      *grad_frame_state = activation(
+          _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
+                                                    *value_update_gate)),
+          *value_frame_state, act_input);
+    } else {
+      *grad_update_gate = _mm256_mul_ps(
+          *grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
+      *grad_prev_out = _mm256_add_ps(
+          *grad_prev_out,
+          _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
+                                                    *value_update_gate)));
+      *grad_frame_state =
+          activation(_mm256_mul_ps(*grad_output, *value_update_gate),
+                     *value_frame_state, act_input);
+    }
   }
 #endif
 #endif
diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu
index e2c40b739..ec7e4d222 100644
--- a/paddle/fluid/operators/math/gru_compute.cu
+++ b/paddle/fluid/operators/math/gru_compute.cu
@@ -92,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
                       GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate) {
+                      const detail::ActivationType active_gate,
+                      bool origin_mode) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -112,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
           /* is_batch= */ false><<<grid, threads, 0, stream>>>(
           detail::backward::gru_stateGrad<T>(), value.gate_value,
           grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
-          grad.output_grad, frame_size, batch_size, active_node);
+          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
     } else {
       detail::KeGruBackwardStateGrad<
           detail::backward::gru_stateGrad<T>,
           /* is_batch= */ true><<<grid, threads, 0, stream>>>(
           detail::backward::gru_stateGrad<T>(), value.gate_value,
           grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
-          grad.output_grad, frame_size, batch_size, active_node);
+          grad.output_grad, frame_size, batch_size, active_node, origin_mode);
     }
 
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
-- 
GitLab
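
Note (not part of the patch): the AVX branch above encodes the backward pass of
the two GRU output formulations. With origin_mode the forward step is
h = u * h_prev + (1 - u) * c; otherwise it is h = (1 - u) * h_prev + u * c,
where u is the update gate value, c the candidate (frame) state, and h_prev the
previous output. The scalar sketch below models one SIMD lane and checks a
derivative numerically. It is a minimal illustration, not Paddle code: the
names GruStateGrad and StateGrad are made up here, and the candidate activation
is taken as identity, whereas the real kernel folds the activation derivative
in through the activation() call.

// Scalar model of one lane of gru_stateGrad; identity candidate activation.
#include <cassert>
#include <cmath>

struct GruStateGrad {
  float grad_update_gate;  // d loss / d u
  float grad_prev_out;     // accumulated d loss / d h_prev
  float grad_frame_state;  // d loss / d c (before the activation derivative)
};

// origin_mode forward:  h = u * h_prev + (1 - u) * c
// default forward:      h = (1 - u) * h_prev + u * c
GruStateGrad StateGrad(float u, float c, float h_prev, float grad_h,
                       float grad_prev_in, bool origin_mode) {
  GruStateGrad g;
  if (origin_mode) {
    g.grad_update_gate = grad_h * (h_prev - c);    // dh/du = h_prev - c
    g.grad_prev_out = grad_prev_in + grad_h * u;   // dh/dh_prev = u
    g.grad_frame_state = grad_h * (1.0f - u);      // dh/dc = 1 - u
  } else {
    g.grad_update_gate = grad_h * (c - h_prev);            // dh/du = c - h_prev
    g.grad_prev_out = grad_prev_in + grad_h * (1.0f - u);  // dh/dh_prev = 1 - u
    g.grad_frame_state = grad_h * u;                       // dh/dc = u
  }
  return g;
}

int main() {
  // Central-difference check of dh/du in origin mode.
  const float u = 0.3f, c = 0.8f, h_prev = 0.5f, eps = 1e-3f;
  auto h = [&](float uu) { return uu * h_prev + (1.0f - uu) * c; };
  const float numeric = (h(u + eps) - h(u - eps)) / (2.0f * eps);
  const float analytic =
      StateGrad(u, c, h_prev, 1.0f, 0.0f, true).grad_update_gate;
  assert(std::fabs(numeric - analytic) < 1e-3f);
  return 0;
}

Because the identity-activation forward map is linear in u, the central
difference recovers dh/du = h_prev - c exactly, which corresponds to
_mm256_sub_ps(*value_prev_out, *value_frame_state) in the origin_mode branch
of the patched AVX kernel.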