diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index f5a7837b53650e08f9632b499a4c2ab1faeaeedf..4393832931c95ca20e34ca3b3d2fb4501274b15f 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -26,7 +26,8 @@ namespace mir { void ConvBNFusePass::Apply(const std::unique_ptr& graph) { // initialze fuser params std::vector conv_has_bias_cases{true, false}; - std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; + std::vector conv_type_cases{ + "conv2d", "depthwise_conv2d", "conv2d_transpose"}; // start fuse using params for (auto conv_has_bias : conv_has_bias_cases) { for (auto conv_type : conv_type_cases) { diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index 0f5bb64e10dd61c3edf4ddd32569a2d365651cdf..150a6e68d8a924ebfa96fdffb99e28b230689a48 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -103,10 +103,17 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { std::string conv_weight_name = matched.at("conv_weight")->arg()->name; auto conv_weight_t = scope->FindVar(conv_weight_name)->GetMutable(); - CHECK_EQ(static_cast(bn_scale_t->data_size()), - static_cast(conv_weight_t->dims()[0])) - << "The BN bias's size should be equal to the size of the first " - << "dim size of the conv weights"; + if (conv_type_ == "conv2d_transpose") { + CHECK_EQ(static_cast(bn_scale_t->data_size()), + static_cast(conv_weight_t->dims()[1])) + << "The BN bias's size should be equal to the size of the first " + << "dim size of the conv weights"; + } else { + CHECK_EQ(static_cast(bn_scale_t->data_size()), + static_cast(conv_weight_t->dims()[0])) + << "The BN bias's size should be equal to the size of the first " + << "dim size of the conv weights"; + } size_t weight_num = conv_weight_t->data_size(); bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false; bool is_weight_quantization = @@ -153,12 +160,29 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { // compute new conv_weight for int8 auto weight_scale = conv_op_desc->GetAttr>("weight_scale"); - for (unsigned int i = 0; i < h; ++i) { - weight_scale[i] *= fabsf(alpha_data[i]); - if (alpha_data[i] < 0.f) { - auto ptr_row = conv_weight_d + i * w; - for (unsigned int j = 0; j < w; ++j) { - ptr_row[j] *= -1; + if (conv_type_ == "conv2d_transpose") { + int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * + conv_weight_t->dims()[3]; + int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; + for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) { + for (unsigned int i = 0; i < h; ++i) { + weight_scale[i] *= fabsf(alpha_data[i]); + if (alpha_data[i] < 0.f) { + auto ptr_row = conv_weight_d + k * c_size + i * hw; + for (unsigned int j = 0; j < hw; ++j) { + ptr_row[j] *= -1; + } + } + } + } + } else { + for (unsigned int i = 0; i < h; ++i) { + weight_scale[i] *= fabsf(alpha_data[i]); + if (alpha_data[i] < 0.f) { + auto ptr_row = conv_weight_d + i * w; + for (unsigned int j = 0; j < w; ++j) { + ptr_row[j] *= -1; + } } } } @@ -176,9 +200,23 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } else { // compute new conv_weight auto conv_weight_d = conv_weight_t->mutable_data(); - for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels - for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels - conv_weight_d[i * w + j] *= alpha_data[i]; + if (conv_type_ == "conv2d_transpose") { + int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * + conv_weight_t->dims()[3]; + int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; + for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) { + for (unsigned int i = 0; i < h; ++i) { + auto ptr_row = conv_weight_d + k * c_size + i * hw; + for (unsigned int j = 0; j < hw; ++j) { + ptr_row[j] *= alpha_data[i]; + } + } + } + } else { + for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels + for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels + conv_weight_d[i * w + j] *= alpha_data[i]; + } } } }