未验证 提交 43777438 编写于 作者: H HappyAngel 提交者: GitHub

[BUG FIX][Pass]fix depthwise deconv+bn fusion (#3480)

* fix format, test=develop

* add some op infershape implement, test=develop

* add reshape infershape, test=develop

* fix depthwise_deconv error. test=develop

* fix format. test=develop
上级 6a04c221
...@@ -103,9 +103,12 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -103,9 +103,12 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
std::string conv_weight_name = matched.at("conv_weight")->arg()->name; std::string conv_weight_name = matched.at("conv_weight")->arg()->name;
auto conv_weight_t = auto conv_weight_t =
scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>(); scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
auto groups = conv_op_desc->GetAttr<int>("groups");
bool depthwise = false;
if (conv_type_ == "conv2d_transpose") { if (conv_type_ == "conv2d_transpose") {
depthwise = (conv_weight_t->dims()[0] == conv_weight_t->dims()[1] * groups);
CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()), CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
static_cast<size_t>(conv_weight_t->dims()[1])) static_cast<size_t>(conv_weight_t->dims()[1] * groups))
<< "The BN bias's size should be equal to the size of the first " << "The BN bias's size should be equal to the size of the first "
<< "dim size of the conv weights"; << "dim size of the conv weights";
} else { } else {
...@@ -159,7 +162,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -159,7 +162,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
// compute new conv_weight for int8 // compute new conv_weight for int8
auto weight_scale = auto weight_scale =
conv_op_desc->GetAttr<std::vector<float>>("weight_scale"); conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
if (conv_type_ == "conv2d_transpose") { if (conv_type_ == "conv2d_transpose" && !depthwise) {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3]; conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
...@@ -199,7 +202,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -199,7 +202,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else { } else {
// compute new conv_weight // compute new conv_weight
auto conv_weight_d = conv_weight_t->mutable_data<float>(); auto conv_weight_d = conv_weight_t->mutable_data<float>();
if (conv_type_ == "conv2d_transpose") { if (conv_type_ == "conv2d_transpose" && !depthwise) {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] * int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3]; conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3]; int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册