diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
index 8d133f5327d28abd57356ab5d874cf57368ca1e2..6b57da1046a05b15b9c3302104d9f4d12c52227f 100644
--- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
+++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h
@@ -110,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *output_grad,
                                        int frame_size, int batch_size,
-                                       ActivationType active_node) {
+                                       ActivationType active_node,
+                                       bool origin_mode) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -140,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
   op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
                 &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
-                &r_out_grad, active_node);
+                &r_out_grad, active_node, origin_mode);
 
   gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
   gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/fluid/operators/math/detail/gru_kernel.h
index fa4f94b6b0ad725b0edcb56abe7d2a0e08d0b28e..c464d9cec4b3cb12902b8233c5b07d55175317ce 100644
--- a/paddle/fluid/operators/math/detail/gru_kernel.h
+++ b/paddle/fluid/operators/math/detail/gru_kernel.h
@@ -132,16 +132,26 @@ class gru_stateGrad {
                   __m256 *grad_frame_state, __m256 *value_prev_out,
                   __m256 *grad_prev_out, __m256 *grad_output,
                   ActivationType act_input, bool origin_mode) {
-    *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
-    *grad_update_gate = _mm256_sub_ps(
-        *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
-    *grad_prev_out = _mm256_add_ps(
-        _mm256_sub_ps(*grad_prev_out,
-                      _mm256_mul_ps(*grad_output, *value_update_gate)),
-        *grad_output);
-    *grad_frame_state =
-        activation(_mm256_mul_ps(*grad_output, *value_update_gate),
-                   *value_frame_state, act_input);
+    if (origin_mode) {
+      *grad_update_gate = _mm256_mul_ps(
+          *grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
+      *grad_prev_out = _mm256_add_ps(
+          *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
+      *grad_frame_state = activation(
+          _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
+                                                    *value_update_gate)),
+          *value_frame_state, act_input);
+    } else {
+      *grad_update_gate = _mm256_mul_ps(
+          *grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
+      *grad_prev_out = _mm256_add_ps(
+          *grad_prev_out,
+          _mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
+                                                    *value_update_gate)));
+      *grad_frame_state =
+          activation(_mm256_mul_ps(*grad_output, *value_update_gate),
+                     *value_frame_state, act_input);
+    }
   }
 #endif
 #endif
diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu
index e2c40b739542941bf9016b6414f359eede845bde..ec7e4d2228c38161bb1f3f97ec21b91db454adb4 100644
--- a/paddle/fluid/operators/math/gru_compute.cu
+++ b/paddle/fluid/operators/math/gru_compute.cu
@@ -92,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
                       GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
                       const detail::ActivationType active_node,
-                      const detail::ActivationType active_gate) {
+                      const detail::ActivationType active_gate,
+                      bool origin_mode) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -112,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
-         grad.output_grad, frame_size, batch_size, active_node);
+         grad.output_grad, frame_size, batch_size, active_node, origin_mode);
    } else {
      detail::KeGruBackwardStateGrad<
          detail::backward::gru_stateGrad<T>,
          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
          detail::backward::gru_stateGrad<T>(), value.gate_value,
          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
-         grad.output_grad, frame_size, batch_size, active_node);
+         grad.output_grad, frame_size, batch_size, active_node, origin_mode);
    }
 
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
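Reviewer note, not part of the patch: the two branches added to gru_stateGrad correspond to h = u * h_prev + (1 - u) * c when origin_mode is true (the original GRU paper formulation) and h = u * c + (1 - u) * h_prev otherwise. Below is a minimal scalar sketch of the same backward step, using hypothetical names (u, c, h_prev, d_h) and a plain tanh derivative in place of Paddle's activation backward helper, intended only to make the intrinsics above easier to check.

// Scalar illustration of the state-gradient math in both modes (assumption:
// the frame-state activation is tanh, so its backward is d * (1 - value^2)).
#include <cstdio>

float tanh_grad(float d, float value) { return d * (1.0f - value * value); }

// Backward of h = u * h_prev + (1 - u) * c   (origin_mode == true)
//          or h = u * c      + (1 - u) * h_prev   (origin_mode == false)
void gru_state_grad(float u, float c, float h_prev, float d_h, float *d_u,
                    float *d_c, float *d_h_prev, bool origin_mode) {
  if (origin_mode) {
    *d_u = d_h * (h_prev - c);            // dh/du = h_prev - c
    *d_h_prev += d_h * u;                 // dh/dh_prev = u
    *d_c = tanh_grad(d_h * (1.0f - u), c);  // dh/dc = 1 - u, through tanh
  } else {
    *d_u = d_h * (c - h_prev);            // dh/du = c - h_prev
    *d_h_prev += d_h * (1.0f - u);        // dh/dh_prev = 1 - u
    *d_c = tanh_grad(d_h * u, c);         // dh/dc = u, through tanh
  }
}

int main() {
  float d_u = 0.f, d_c = 0.f, d_h_prev = 0.f;
  gru_state_grad(/*u=*/0.6f, /*c=*/0.25f, /*h_prev=*/0.4f, /*d_h=*/1.0f, &d_u,
                 &d_c, &d_h_prev, /*origin_mode=*/true);
  std::printf("d_u=%f d_c=%f d_h_prev=%f\n", d_u, d_c, d_h_prev);
  return 0;
}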