提交 e0b0b945 编写于 作者: 李寅

Fix fc bias nullptr

上级 3c5da142
......@@ -64,17 +64,20 @@ struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_weight(weight);
Tensor::MappingGuard guard_bias(bias);
Tensor::MappingGuard guard_output(output);
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
for (int i = 0; i < N; ++i) {
for (int j = 0; j < output_size; ++j) {
output_ptr[j + i * output_size] += bias_ptr[j];
if (bias) {
Tensor::MappingGuard guard_bias(bias);
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
for (int i = 0; i < N; ++i) {
for (int j = 0; j < output_size; ++j) {
output_ptr[j + i * output_size] += bias_ptr[j];
}
}
}
......
......@@ -42,17 +42,23 @@ class FullyConnectedOp : public Operator<D, T> {
if (D == DeviceType::CPU) {
MACE_CHECK(
input->dim(1) == weight->dim(1) && input->dim(2) == weight->dim(2) &&
input->dim(3) == weight->dim(3) && weight->dim(0) == bias->dim(0),
input->dim(3) == weight->dim(3),
"The shape of Input: ", MakeString(input->shape()),
"The shape of Weight: ", MakeString(weight->shape()), " and Bias ",
bias->dim(0), " don't match.");
"The shape of Weight: ", MakeString(weight->shape()),
" don't match.");
} else {
MACE_CHECK(
input->dim(1) == weight->dim(2) && input->dim(2) == weight->dim(3) &&
input->dim(3) == weight->dim(1) && weight->dim(0) == bias->dim(0),
input->dim(3) == weight->dim(1),
"The shape of Input: ", MakeString(input->shape()),
"The shape of Weight: ", MakeString(weight->shape()), " and Bias ",
bias->dim(0), " don't match.");
"The shape of Weight: ", MakeString(weight->shape()),
" don't match.");
}
if (bias) {
MACE_CHECK(weight->dim(0) == bias->dim(0),
"The shape of Weight: ", MakeString(weight->shape()),
" and shape of Bias: ", bias->dim(0),
" don't match.");
}
return functor_(input, weight, bias, output, future);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册