diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h index a6fbebd61d95d29ec9a03eda08cfda06291ecd73..7a337a9d847ff658e4e48304122abe0ecd10e269 100644 --- a/mace/kernels/fully_connected.h +++ b/mace/kernels/fully_connected.h @@ -64,17 +64,20 @@ struct FullyConnectedFunctor: FullyConnectedBase { Tensor::MappingGuard guard_input(input); Tensor::MappingGuard guard_weight(weight); - Tensor::MappingGuard guard_bias(bias); Tensor::MappingGuard guard_output(output); const float *input_ptr = input->data(); const float *weight_ptr = weight->data(); - const float *bias_ptr = bias == nullptr ? nullptr : bias->data(); float *output_ptr = output->mutable_data(); Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr); - for (int i = 0; i < N; ++i) { - for (int j = 0; j < output_size; ++j) { - output_ptr[j + i * output_size] += bias_ptr[j]; + + if (bias) { + Tensor::MappingGuard guard_bias(bias); + const float *bias_ptr = bias == nullptr ? nullptr : bias->data(); + for (int i = 0; i < N; ++i) { + for (int j = 0; j < output_size; ++j) { + output_ptr[j + i * output_size] += bias_ptr[j]; + } } } diff --git a/mace/ops/fully_connected.h b/mace/ops/fully_connected.h index 286b22580ee6c8c42df61410748e607781522bb1..8ec0039185366a5419cf2b56dfd9317b3a5342a3 100644 --- a/mace/ops/fully_connected.h +++ b/mace/ops/fully_connected.h @@ -42,17 +42,23 @@ class FullyConnectedOp : public Operator { if (D == DeviceType::CPU) { MACE_CHECK( input->dim(1) == weight->dim(1) && input->dim(2) == weight->dim(2) && - input->dim(3) == weight->dim(3) && weight->dim(0) == bias->dim(0), + input->dim(3) == weight->dim(3), "The shape of Input: ", MakeString(input->shape()), - "The shape of Weight: ", MakeString(weight->shape()), " and Bias ", - bias->dim(0), " don't match."); + "The shape of Weight: ", MakeString(weight->shape()), + " don't match."); } else { MACE_CHECK( input->dim(1) == weight->dim(2) && input->dim(2) == weight->dim(3) && - input->dim(3) == weight->dim(1) && weight->dim(0) == bias->dim(0), + input->dim(3) == weight->dim(1), "The shape of Input: ", MakeString(input->shape()), - "The shape of Weight: ", MakeString(weight->shape()), " and Bias ", - bias->dim(0), " don't match."); + "The shape of Weight: ", MakeString(weight->shape()), + " don't match."); + } + if (bias) { + MACE_CHECK(weight->dim(0) == bias->dim(0), + "The shape of Weight: ", MakeString(weight->shape()), + " and shape of Bias: ", bias->dim(0), + " don't match."); } return functor_(input, weight, bias, output, future);