未验证 提交 18709770 编写于 作者: H HappyAngel 提交者: GitHub

[BUG FIX][Pass]fix depthwise deconv+bn fusion (#3480) (#3519)

*fix depthwise deconv+bn fusion
上级 717c176d
......@@ -103,9 +103,12 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
std::string conv_weight_name = matched.at("conv_weight")->arg()->name;
auto conv_weight_t =
scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
auto groups = conv_op_desc->GetAttr<int>("groups");
bool depthwise = false;
if (conv_type_ == "conv2d_transpose") {
depthwise = (conv_weight_t->dims()[0] == conv_weight_t->dims()[1] * groups);
CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
static_cast<size_t>(conv_weight_t->dims()[1]))
static_cast<size_t>(conv_weight_t->dims()[1] * groups))
<< "The BN bias's size should be equal to the size of the first "
<< "dim size of the conv weights";
} else {
......@@ -159,7 +162,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
// compute new conv_weight for int8
auto weight_scale =
conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
if (conv_type_ == "conv2d_transpose") {
if (conv_type_ == "conv2d_transpose" && !depthwise) {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
......@@ -199,7 +202,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else {
// compute new conv_weight
auto conv_weight_d = conv_weight_t->mutable_data<float>();
if (conv_type_ == "conv2d_transpose") {
if (conv_type_ == "conv2d_transpose" && !depthwise) {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册