提交 d7f67c5c 编写于 作者: Z zhupengyang

fix conv_elemetwise_fuse when elemetwise_bias need broadcast

test=develop
上级 ddc1b571
...@@ -92,6 +92,27 @@ void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph, ...@@ -92,6 +92,27 @@ void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph,
auto elementwise_add_bias_d = auto elementwise_add_bias_d =
elementwise_add_bias_t->mutable_data<float>(); elementwise_add_bias_t->mutable_data<float>();
auto conv_bias_size = conv_bias_t->numel();
auto elemetwise_bias_size = elementwise_add_bias_t->numel();
// If elements size of `elemwise_bias` and `conv_bias` are not same,
// `elemwise_bias` should be broadcast to the same size of `conv_bias`
if (conv_bias_size != elemetwise_bias_size && elemetwise_bias_size == 1) {
auto data_tmp = elementwise_add_bias_d[0];
elementwise_add_bias_t->Resize({conv_bias_size});
elementwise_add_bias_d = elementwise_add_bias_t->mutable_data<float>();
for (int64_t i = 0; i < conv_bias_size; i++) {
elementwise_add_bias_d[i] = data_tmp;
}
}
if (conv_bias_t->numel() != elementwise_add_bias_t->numel()) {
LOG(WARNING) << "Elements size of `elemwise_bias` and `conv_bias` "
"should be the same, but get size of `elemwise_bias` "
"is: "
<< elementwise_add_bias_t->numel()
<< ", size of `conv_bias` is: " << conv_bias_t->numel();
return;
}
for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) { for (unsigned int i = 0; i < conv_bias_t->data_size(); ++i) {
elementwise_add_bias_d[i] += conv_bias_d[i]; elementwise_add_bias_d[i] += conv_bias_d[i];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册