From 43777438365d1c755863282559a60228c979d5ac Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Sat, 25 Apr 2020 03:36:29 -0500 Subject: [PATCH] [BUG FIX][Pass]fix depthwise deconv+bn fusion (#3480) * fix format, test=develop * add some op infershape implement, test=develop * add reshape infershape, test=develop * fix depthwise_deconv error. test=develop * fix format. test=develop --- lite/core/mir/fusion/conv_bn_fuser.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index 143a7cecce..6718356788 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -103,9 +103,12 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { std::string conv_weight_name = matched.at("conv_weight")->arg()->name; auto conv_weight_t = scope->FindVar(conv_weight_name)->GetMutable(); + auto groups = conv_op_desc->GetAttr("groups"); + bool depthwise = false; if (conv_type_ == "conv2d_transpose") { + depthwise = (conv_weight_t->dims()[0] == conv_weight_t->dims()[1] * groups); CHECK_EQ(static_cast(bn_scale_t->data_size()), - static_cast(conv_weight_t->dims()[1])) + static_cast(conv_weight_t->dims()[1] * groups)) << "The BN bias's size should be equal to the size of the first " << "dim size of the conv weights"; } else { @@ -159,7 +162,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { // compute new conv_weight for int8 auto weight_scale = conv_op_desc->GetAttr>("weight_scale"); - if (conv_type_ == "conv2d_transpose") { + if (conv_type_ == "conv2d_transpose" && !depthwise) { int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; @@ -199,7 +202,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } else { // compute new conv_weight auto conv_weight_d = conv_weight_t->mutable_data(); - if (conv_type_ == "conv2d_transpose") { + if (conv_type_ == "conv2d_transpose" && !depthwise) { int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; -- GitLab