diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 087f903a8bba9a4bfcd7eaabd7098555442a904e..752d706cbfab8eb3027fe9610c25b7400ecfed1d 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, defalut: False) " "whether to compute reversed GRU.") .SetDefault(false); + AddAttr("origin_mode", + "bool" + "use origin mode in article https://arxiv.org/abs/1412.3555") + .SetDefault(false); AddComment(R"DOC( GRU Operator implements part calculations of the complete GRU as following: @@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { using DeviceContext = paddle::platform::CPUDeviceContext; + bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); @@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel { math::detail::forward_final_output( math::detail::forward::gru_finalOutput(), gru_value, frame_size, - cur_batch_size, active_node); + cur_batch_size, active_node, origin_mode); gru_value.prev_out_value = gru_value.output_value; } @@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel { math::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + active_gate, origin_mode); gru_value.prev_out_value = gru_value.output_value; } diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index 55721c283dd18c2f9642563a9ce1eabfce16fd7b..ba918b3def22e3c60c4155f77ecbaad85d520928 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -21,6 +21,7 @@ template class GRUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { + bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); @@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel { gru_value.reset_output_value = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( dev_ctx, gru_value, frame_size, cur_batch_size, active_node, - active_gate); + active_gate, origin_mode); gru_value.prev_out_value = gru_value.output_value; } diff --git a/paddle/fluid/operators/gru_op.h b/paddle/fluid/operators/gru_op.h index 0b551e8046be16c95f7d6b10b68b32a9af594f73..45c769ee37260bf912ebc848d58019557f4adc07 100644 --- a/paddle/fluid/operators/gru_op.h +++ b/paddle/fluid/operators/gru_op.h @@ -41,6 +41,7 @@ template class GRUGradKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { + bool origin_mode = context.Attr("origin_mode"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); @@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel { math::GRUUnitGradFunctor::compute( dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, - active_gate); + active_gate, origin_mode); } if (input_grad) { input_grad->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h index 47c771f7c5c01b651423c7886207abf4a4297019..6e74e124fc2b2bef4e5128e02bcc2beb27b7db23 100644 --- a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h @@ -56,7 +56,8 @@ template void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); + &r_output, active_node, origin_mode); frame_state[i] = r_value_frame_state; output_value[i] = r_output; @@ -146,7 +147,8 @@ template void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { #ifdef __AVX__ __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); @@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); + &r_output, active_node, origin_mode); _mm256_storeu_ps(reinterpret_cast(frame_state + i), r_value_frame_state); @@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, if (rest > 0) { i = n - block; op_final_output(&r_value_update_gate_last, &r_value_frame_state_last, - &r_prev_out_last, &r_output, active_node); + &r_prev_out_last, &r_output, active_node, origin_mode); _mm256_storeu_ps(reinterpret_cast(frame_state + i), r_value_frame_state_last); @@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output, template inline void forward_final_output(OpFinalOutput op_final_output, GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_node) { + int batch_size, ActivationType active_node, + bool origin_mode) { for (int b = 0; b < batch_size; b++) { if (OpFinalOutput::avx && (frame_size > static_cast(8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(op_final_output, value.gate_value, value.prev_out_value, value.output_value, - frame_size, active_node); + frame_size, active_node, origin_mode); } else { hl_naive_gru_forward_final_output( op_final_output, value.gate_value, value.prev_out_value, - value.output_value, frame_size, active_node); + value.output_value, frame_size, active_node, origin_mode); } value.gate_value += frame_size * 3; diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 813d69f6aba722609a0523a5be71d32f91f76d59..8d133f5327d28abd57356ab5d874cf57368ca1e2 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, int batch_size, - ActivationType active_node) { + ActivationType active_node, + bool origin_mode) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); + &r_output, active_node, origin_mode); gate_value[frame_idx + frame_size * 2] = r_value_frame_state; output_value[frame_idx] = r_output; diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/fluid/operators/math/detail/gru_kernel.h index f6d192358bd84eb56a2e01eb36f28d8832ef271f..d978bd95c87446524748c3d45b2726a27821d04b 100644 --- a/paddle/fluid/operators/math/detail/gru_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_kernel.h @@ -57,10 +57,16 @@ class gru_finalOutput { public: HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state, T *prev_out, T *value_output, - ActivationType act_input) { + ActivationType act_input, bool origin_mode) { *value_frame_state = activation(*value_frame_state, act_input); - *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + - ((*value_update_gate) * (*value_frame_state)); + if (origin_mode) { + *value_output = ((*value_update_gate) * (*prev_out)) + + *value_frame_state - + ((*value_update_gate) * (*value_frame_state)); + } else { + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + + ((*value_update_gate) * (*value_frame_state)); + } } #ifndef __NVCC__ #ifndef __AVX__ @@ -69,11 +75,20 @@ class gru_finalOutput { static const bool avx = true; HOSTDEVICE void operator()(__m256 *value_update_gate, __m256 *value_frame_state, __m256 *prev_out, - __m256 *value_output, ActivationType act_input) { + __m256 *value_output, ActivationType act_input, + bool origin_mode) { *value_frame_state = activation(*value_frame_state, act_input); - *value_output = _mm256_add_ps( - _mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)), - _mm256_mul_ps(*value_update_gate, *value_frame_state)); + if (origin_mode) { + *value_output = _mm256_sub_ps( + _mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out), + *value_frame_state), + _mm256_mul_ps(*value_update_gate, *value_frame_state)); + } else { + *value_output = _mm256_add_ps( + _mm256_sub_ps(*prev_out, + _mm256_mul_ps(*value_update_gate, *prev_out)), + _mm256_mul_ps(*value_update_gate, *value_frame_state)); + } } #endif #endif diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index 0e15b81deef43a932d4b2d3f545393b0ad9e080c..295b75356c060332cfb0c561b4170815b47f61b6 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -23,7 +23,8 @@ struct GRUUnitFunctor { static void compute(const platform::CPUDeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { #ifndef __NVCC__ auto blas = math::GetBlas(context); if (value.prev_out_value) { @@ -43,7 +44,8 @@ struct GRUUnitFunctor { } detail::forward_final_output(detail::forward::gru_finalOutput(), value, - frame_size, batch_size, active_node); + frame_size, batch_size, active_node, + origin_mode); #endif } }; @@ -54,7 +56,8 @@ struct GRUUnitGradFunctor { GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, grad, frame_size, batch_size, active_node); diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index 1327d914952d57aab6e5d17090d0ea976a6d4755..e2c40b739542941bf9016b6414f359eede845bde 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -24,7 +24,8 @@ struct GRUUnitFunctor { static void compute(const platform::CUDADeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate) { + const detail::ActivationType active_gate, + bool origin_mode) { auto stream = context.stream(); dim3 threads; dim3 grid; @@ -73,14 +74,14 @@ struct GRUUnitFunctor { T><<>>( detail::forward::gru_finalOutput(), value.gate_value, value.prev_out_value, value.output_value, frame_size, batch_size, - active_node); + active_node, origin_mode); } else { detail::KeGruForwardFinalOutput, /* is_batch= */ true, T><<>>( detail::forward::gru_finalOutput(), value.gate_value, value.prev_out_value, value.output_value, frame_size, batch_size, - active_node); + active_node, origin_mode); } } }; diff --git a/paddle/fluid/operators/math/gru_compute.h b/paddle/fluid/operators/math/gru_compute.h index c5816b16cd90410fcc48929931c25d0d561ad653..f5ddec0aaa275a32a5a9937699066a170edc0825 100644 --- a/paddle/fluid/operators/math/gru_compute.h +++ b/paddle/fluid/operators/math/gru_compute.h @@ -44,7 +44,8 @@ struct GRUUnitFunctor { static void compute(const DeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate); + const detail::ActivationType active_gate, + bool origin_mode); }; template @@ -52,7 +53,8 @@ struct GRUUnitGradFunctor { static void compute(const DeviceContext &context, GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, const detail::ActivationType active_node, - const detail::ActivationType active_gate); + const detail::ActivationType active_gate, + bool origin_mode); }; } // namespace math