diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc index f94da2f1b1fc0a0d4ca17718f9407a4a56c544fe..d7333efde9fa3396dbec1a5efffe079ac17dcce1 100644 --- a/lite/core/mir/fusion/conv_elementwise_fuser.cc +++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc @@ -92,6 +92,27 @@ void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph, auto elementwise_add_bias_d = elementwise_add_bias_t->mutable_data(); + auto conv_bias_size = conv_bias_t->numel(); + auto elemetwise_bias_size = elementwise_add_bias_t->numel(); + // If elements size of `elemwise_bias` and `conv_bias` are not same, + // `elemwise_bias` should be broadcast to the same size of `conv_bias` + if (conv_bias_size != elemetwise_bias_size && elemetwise_bias_size == 1) { + auto data_tmp = elementwise_add_bias_d[0]; + elementwise_add_bias_t->Resize({conv_bias_size}); + elementwise_add_bias_d = elementwise_add_bias_t->mutable_data(); + for (int64_t i = 0; i < conv_bias_size; i++) { + elementwise_add_bias_d[i] = data_tmp; + } + } + if (conv_bias_t->numel() != elementwise_add_bias_t->numel()) { + LOG(WARNING) << "Elements size of `elemwise_bias` and `conv_bias` " + "should be the same, but get size of `elemwise_bias` " + "is: " + << elementwise_add_bias_t->numel() + << ", size of `conv_bias` is: " << conv_bias_t->numel(); + return; + } + for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) { elementwise_add_bias_d[i] += conv_bias_d[i]; }