diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h index 6e74e124fc2b2bef4e5128e02bcc2beb27b7db23..f3f5cad0cb454e79be0a5567760ea6352d8c1fa4 100644 --- a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h @@ -256,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { T r_update_gate_value; T r_update_gate_grad; T r_frame_state_value; @@ -282,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, - &r_prev_out_grad, &r_out_grad, active_node); + &r_prev_out_grad, &r_out_grad, active_node, origin_mode); update_gate_grad[i] = r_update_gate_grad; frame_state_grad[i] = r_frame_state_grad; @@ -297,7 +298,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, - ActivationType active_gate) { + ActivationType active_gate, + bool origin_mode) { T r_update_gate_value; T r_update_gate_grad; T r_reset_gate_value; @@ -327,7 +329,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value, - &r_prev_out_grad, &r_reset_output_grad, active_gate); + &r_prev_out_grad, &r_reset_output_grad, active_gate, + origin_mode); update_gate_grad[i] = r_update_gate_grad; reset_gate_grad[i] = r_reset_gate_grad; @@ -341,8 +344,8 @@ template void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, - int frame_size, - ActivationType active_node) { + int frame_size, ActivationType active_node, + bool origin_mode) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -371,7 +374,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value, - &r_prev_out_grad, &r_out_grad, active_node); + &r_prev_out_grad, &r_out_grad, active_node, origin_mode); update_gate_grad[i] = r_update_gate_grad; frame_state_grad[i] = r_frame_state_grad; @@ -386,8 +389,8 @@ template void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, - int frame_size, - ActivationType active_gate) { + int frame_size, ActivationType active_gate, + bool origin_mode) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -419,7 +422,8 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value, - &r_prev_out_grad, &r_reset_output_grad, active_gate); + &r_prev_out_grad, &r_reset_output_grad, active_gate, + origin_mode); update_gate_grad[i] = r_update_gate_grad; reset_gate_grad[i] = r_reset_gate_grad; @@ -434,16 +438,18 @@ template inline void backward_state_grad(OpStateGrad op_state_grad, GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - ActivationType active_node) { + ActivationType active_node, bool origin_mode) { for (int b = 0; b < batch_size; b++) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_backward_state_grad( - op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.output_grad, frame_size, active_node); + hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value, + grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.output_grad, + frame_size, active_node, origin_mode); } else { - hl_naive_gru_backward_state_grad( - op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.output_grad, frame_size, active_node); + hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value, + grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.output_grad, + frame_size, active_node, origin_mode); } value.gate_value += frame_size * 3; @@ -463,16 +469,18 @@ template inline void backward_reset_grad(OpResetGrad op_reset_grad, GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - ActivationType active_gate) { + ActivationType active_gate, bool origin_mode) { for (int b = 0; b < batch_size; b++) { if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { - hl_avx_gru_backward_reset_grad( - op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); + hl_avx_gru_backward_reset_grad(op_reset_grad, value.gate_value, + grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, grad.reset_output_grad, + frame_size, active_gate, origin_mode); } else { hl_naive_gru_backward_reset_grad( op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value, - grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate); + grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate, + origin_mode); } value.gate_value += frame_size * 3; diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/fluid/operators/math/detail/gru_kernel.h index d978bd95c87446524748c3d45b2726a27821d04b..fa4f94b6b0ad725b0edcb56abe7d2a0e08d0b28e 100644 --- a/paddle/fluid/operators/math/detail/gru_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_kernel.h @@ -103,13 +103,23 @@ class gru_stateGrad { HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, T *value_frame_state, T *grad_frame_state, T *value_prev_out, T *grad_prev_out, - T *grad_output, ActivationType act_input) { - *grad_update_gate = (*grad_output * (*value_frame_state)); - *grad_update_gate -= (*grad_output * (*value_prev_out)); - *grad_prev_out -= (*grad_output * (*value_update_gate)); - *grad_prev_out += *grad_output; - *grad_frame_state = activation(*grad_output * (*value_update_gate), - *value_frame_state, act_input); + T *grad_output, ActivationType act_input, + bool origin_mode) { + if (origin_mode) { + *grad_update_gate = + (*grad_output) * ((*value_prev_out) - (*value_frame_state)); + *grad_prev_out += (*grad_output * (*value_update_gate)); + *grad_frame_state = activation( + *grad_output * (static_cast(1.0) - (*value_update_gate)), + *value_frame_state, act_input); + } else { + *grad_update_gate = + (*grad_output) * ((*value_frame_state) - (*value_prev_out)); + *grad_prev_out += + (*grad_output * (static_cast(1.0) - *value_update_gate)); + *grad_frame_state = activation(*grad_output * (*value_update_gate), + *value_frame_state, act_input); + } } #ifndef __NVCC__ #ifndef __AVX__ @@ -121,7 +131,7 @@ class gru_stateGrad { __m256 *value_frame_state, __m256 *grad_frame_state, __m256 *value_prev_out, __m256 *grad_prev_out, __m256 *grad_output, - ActivationType act_input) { + ActivationType act_input, bool origin_mode) { *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state); *grad_update_gate = _mm256_sub_ps( *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out)); @@ -143,7 +153,8 @@ class gru_resetGrad { HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate, T *value_reset_gate, T *grad_reset_gate, T *value_prev_out, T *grad_prev_out, - T *grad_reset_output, ActivationType act_gate) { + T *grad_reset_output, ActivationType act_gate, + bool origin_mode) { *grad_reset_gate = (*grad_reset_output * (*value_prev_out)); *grad_prev_out += (*grad_reset_output * (*value_reset_gate)); *grad_update_gate = @@ -160,7 +171,7 @@ class gru_resetGrad { __m256 *grad_update_gate, __m256 *value_reset_gate, __m256 *grad_reset_gate, __m256 *value_prev_out, __m256 *grad_prev_out, __m256 *grad_reset_output, - ActivationType act_gate) { + ActivationType act_gate, bool origin_mode) { *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out); *grad_prev_out = _mm256_add_ps( *grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate)); diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index 295b75356c060332cfb0c561b4170815b47f61b6..b875f7d4f4bbe7037309ad01e52629a2da383e27 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -60,7 +60,8 @@ struct GRUUnitGradFunctor { bool origin_mode) { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, - grad, frame_size, batch_size, active_node); + grad, frame_size, batch_size, active_node, + origin_mode); auto blas = math::GetBlas(context); if (value.prev_out_value && grad.prev_out_grad) { blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, @@ -77,7 +78,8 @@ struct GRUUnitGradFunctor { } detail::backward_reset_grad(detail::backward::gru_resetGrad(), value, - grad, frame_size, batch_size, active_gate); + grad, frame_size, batch_size, active_gate, + origin_mode); if (grad.prev_out_grad && value.prev_out_value) { blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, grad.gate_grad, frame_size * 3, value.gate_weight,