提交 3aa021e7 编写于 作者: H HappyAngel 提交者: chenjiaoAngel

fix conv_conv fusion error in conv_dw+conv_1x1 (#4446)

* fix conv_conv fusion error in conv_dw+conv_1x1. test=develop

* test=develop

* fix format. test=develop
上级 d13c94d9
...@@ -27,7 +27,7 @@ namespace mir { ...@@ -27,7 +27,7 @@ namespace mir {
void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialze fuser params // initialze fuser params
std::vector<bool> conv_has_bias_cases{true, false}; std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"}; std::vector<std::string> conv_type_cases{"conv2d"};
bool has_int8 = false; bool has_int8 = false;
bool has_weight_quant = false; bool has_weight_quant = false;
for (auto& place : graph->valid_places()) { for (auto& place : graph->valid_places()) {
......
...@@ -132,8 +132,8 @@ void ConvConvFuser::BuildPattern() { ...@@ -132,8 +132,8 @@ void ConvConvFuser::BuildPattern() {
VLOG(5) << "The kernel size of the second conv must be 1x1"; VLOG(5) << "The kernel size of the second conv must be 1x1";
continue; continue;
} }
if (groups1 != 1) { if (groups0 != 1 || groups1 != 1) {
VLOG(5) << "The groups of weight1_dim must be 1"; VLOG(5) << "The all groups of weight_dim must be 1";
continue; continue;
} }
if (ch_out_0 != ch_in_1) { if (ch_out_0 != ch_in_1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册