diff --git a/micro/ops/matmul.cc b/micro/ops/matmul.cc index 09caeff38e200c0184ccc9ba5e2d8ebe22bb089d..1888d1e94e634c97d3549a0b10ecf177bb8b6f74 100644 --- a/micro/ops/matmul.cc +++ b/micro/ops/matmul.cc @@ -27,9 +27,15 @@ MaceStatus MatMulOp::OnInit() { transpose_b_ = GetArgByName("transpose_b", false); input_a_ = GetInputData(INPUT_A); input_b_ = GetInputData(INPUT_B); - bias_ = GetInputSize() > 3 ? GetInputData(BIAS) : NULL; output_ = GetOutputData(OUTPUT); + bias_ = NULL; + if (GetInputSize() >= 3) { + bias_ = GetInputData(BIAS); + bias_dim_size_ = GetInputShapeDimSize(BIAS); + bias_dims_ = GetInputShapeDims(BIAS); + } + input_a_dim_size_ = GetInputShapeDimSize(INPUT_A); input_b_dim_size_ = GetInputShapeDimSize(INPUT_B); diff --git a/micro/ops/utils/gemv.cc b/micro/ops/utils/gemv.cc index 1fc81c47934497313ed49fb65d91f2741a08de2e..d09bf845dbbd05ed1a14aa623f6fd11162214a47 100644 --- a/micro/ops/utils/gemv.cc +++ b/micro/ops/utils/gemv.cc @@ -98,10 +98,10 @@ MaceStatus Gemv::Compute(const mifloat *lhs_data, float sum2 = 0; float sum3 = 0; if (bias_data != NULL) { - sum0 = bias_data[0]; - sum1 = bias_data[1]; - sum2 = bias_data[2]; - sum3 = bias_data[3]; + sum0 = bias_data[h + 0]; + sum1 = bias_data[h + 1]; + sum2 = bias_data[h + 2]; + sum3 = bias_data[h + 3]; } const int32_t lhs_h_base0 = (lhs_b_base + h) * lhs_width; const int32_t lhs_h_base1 = lhs_h_base0 + lhs_width;