From e0b0b9459f07b6c2110e3c168e243d136369e7aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Mon, 30 Jul 2018 10:51:54 +0800 Subject: [PATCH] Fix fc bias nullptr --- mace/kernels/fully_connected.h | 13 ++++++++----- mace/ops/fully_connected.h | 18 ++++++++++++------ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h index a6fbebd6..7a337a9d 100644 --- a/mace/kernels/fully_connected.h +++ b/mace/kernels/fully_connected.h @@ -64,17 +64,20 @@ struct FullyConnectedFunctor: FullyConnectedBase { Tensor::MappingGuard guard_input(input); Tensor::MappingGuard guard_weight(weight); - Tensor::MappingGuard guard_bias(bias); Tensor::MappingGuard guard_output(output); const float *input_ptr = input->data(); const float *weight_ptr = weight->data(); - const float *bias_ptr = bias == nullptr ? nullptr : bias->data(); float *output_ptr = output->mutable_data(); Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr); - for (int i = 0; i < N; ++i) { - for (int j = 0; j < output_size; ++j) { - output_ptr[j + i * output_size] += bias_ptr[j]; + + if (bias) { + Tensor::MappingGuard guard_bias(bias); + const float *bias_ptr = bias == nullptr ? nullptr : bias->data(); + for (int i = 0; i < N; ++i) { + for (int j = 0; j < output_size; ++j) { + output_ptr[j + i * output_size] += bias_ptr[j]; + } } } diff --git a/mace/ops/fully_connected.h b/mace/ops/fully_connected.h index 286b2258..8ec00391 100644 --- a/mace/ops/fully_connected.h +++ b/mace/ops/fully_connected.h @@ -42,17 +42,23 @@ class FullyConnectedOp : public Operator { if (D == DeviceType::CPU) { MACE_CHECK( input->dim(1) == weight->dim(1) && input->dim(2) == weight->dim(2) && - input->dim(3) == weight->dim(3) && weight->dim(0) == bias->dim(0), + input->dim(3) == weight->dim(3), "The shape of Input: ", MakeString(input->shape()), - "The shape of Weight: ", MakeString(weight->shape()), " and Bias ", - bias->dim(0), " don't match."); + "The shape of Weight: ", MakeString(weight->shape()), + " don't match."); } else { MACE_CHECK( input->dim(1) == weight->dim(2) && input->dim(2) == weight->dim(3) && - input->dim(3) == weight->dim(1) && weight->dim(0) == bias->dim(0), + input->dim(3) == weight->dim(1), "The shape of Input: ", MakeString(input->shape()), - "The shape of Weight: ", MakeString(weight->shape()), " and Bias ", - bias->dim(0), " don't match."); + "The shape of Weight: ", MakeString(weight->shape()), + " don't match."); + } + if (bias) { + MACE_CHECK(weight->dim(0) == bias->dim(0), + "The shape of Weight: ", MakeString(weight->shape()), + " and shape of Bias: ", bias->dim(0), + " don't match."); } return functor_(input, weight, bias, output, future); -- GitLab