diff --git a/src/operators/kernel/arm/gru_unit_kernel.cpp b/src/operators/kernel/arm/gru_unit_kernel.cpp index c9dbcbc7d50bf313000c2e788d3032c2f0cec0e7..bf20f25d7241449abb03a40a7dfd352bc23643af 100644 --- a/src/operators/kernel/arm/gru_unit_kernel.cpp +++ b/src/operators/kernel/arm/gru_unit_kernel.cpp @@ -27,7 +27,7 @@ bool GruUnitKernel::Init(GruUnitParam *param) { template <> void GruUnitKernel::Compute(const GruUnitParam ¶m) { - GruUnitCompute(param); + GruUnitCompute(param); } template class GruUnitKernel; diff --git a/src/operators/kernel/central-arm-func/gru_unit_arm_func.h b/src/operators/kernel/central-arm-func/gru_unit_arm_func.h index 477674c7e9b5bf0d7356e19090c6cb722abf8c68..3897d6a308484f942f4a0cbb03957a37f3b7ab4c 100644 --- a/src/operators/kernel/central-arm-func/gru_unit_arm_func.h +++ b/src/operators/kernel/central-arm-func/gru_unit_arm_func.h @@ -47,7 +47,7 @@ void GruUnitCompute(const GruUnitParam& param) { gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); gru_value.output_value = hidden->data

(); - gru_value.prev_out_value = hidden_prev->data

(); + gru_value.prev_out_value = const_cast(hidden_prev->data

()); gru_value.gate_value = gate->data

(); gru_value.reset_output_value = reset_hidden_prev->data

(); auto active_node = math::GetActivationType(param.Activation());