Commit 10c3b66e authored by HappyAngel, committed by GitHub

Add deconv+batchnorm fusion (#3318)

* add conv_transpose+bn fusion. test=develop

* delete note, test=develop

* fix format space, test=develop

* fix opt run error, test=develop
Parent 51e9898d
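For context on what this pass computes: batch norm applies y = gamma * (x - mean) / sqrt(var + eps) + beta per channel, which folds into the preceding (de)convolution by multiplying each output channel's weights by alpha = gamma / sqrt(var + eps) and adjusting the bias. A minimal standalone sketch of that arithmetic (illustrative names, not the pass's actual variables):

#include <cmath>
#include <vector>

// Fold y = gamma * (x - mean) / sqrt(var + eps) + beta into the producer op:
//   w'[c] = alpha[c] * w[c],  b'[c] = alpha[c] * (b[c] - mean[c]) + beta[c],
// where alpha[c] = gamma[c] / sqrt(var[c] + eps).
void FoldBnParams(const std::vector<float>& gamma,
                  const std::vector<float>& beta,
                  const std::vector<float>& mean,
                  const std::vector<float>& var,
                  float eps,
                  std::vector<float>* alpha,   // per-channel weight multiplier
                  std::vector<float>* bias) {  // folded per-channel bias
  alpha->resize(gamma.size());
  bias->resize(gamma.size());
  for (size_t c = 0; c < gamma.size(); ++c) {
    (*alpha)[c] = gamma[c] / std::sqrt(var[c] + eps);
    (*bias)[c] = beta[c] - (*alpha)[c] * mean[c];
  }
}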
@@ -26,7 +26,8 @@ namespace mir {
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // initialize fuser params
   std::vector<bool> conv_has_bias_cases{true, false};
-  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
   // start fuse using params
   for (auto conv_has_bias : conv_has_bias_cases) {
     for (auto conv_type : conv_type_cases) {
@@ -103,10 +103,17 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   std::string conv_weight_name = matched.at("conv_weight")->arg()->name;
   auto conv_weight_t =
       scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
-  CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
-           static_cast<size_t>(conv_weight_t->dims()[0]))
-      << "The BN scale's size should be equal to the first "
-      << "dim size of the conv weights";
+  if (conv_type_ == "conv2d_transpose") {
+    CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
+             static_cast<size_t>(conv_weight_t->dims()[1]))
+        << "The BN scale's size should be equal to the second "
+        << "dim size of the conv_transpose weights";
+  } else {
+    CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
+             static_cast<size_t>(conv_weight_t->dims()[0]))
+        << "The BN scale's size should be equal to the first "
+        << "dim size of the conv weights";
+  }
   size_t weight_num = conv_weight_t->data_size();
   bool enable_int8 = conv_op_desc->HasAttr("enable_int8");
   bool is_weight_quantization =
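The branch just added reflects the two weight layouts involved: conv2d and depthwise_conv2d store filters as {out_channels, in_channels/groups, kh, kw}, so the BN channel count must match dims()[0], while conv2d_transpose stores filters as {in_channels, out_channels/groups, kh, kw}, which puts the output-channel axis at dims()[1]. An illustrative helper (not part of the pass) that captures the rule:

#include <string>

// Axis of the weight tensor that carries the BN (output) channels.
// conv2d / depthwise_conv2d: {out_c, in_c/groups, kh, kw} -> axis 0
// conv2d_transpose:          {in_c, out_c/groups, kh, kw} -> axis 1
int BnChannelAxis(const std::string& conv_type) {
  return conv_type == "conv2d_transpose" ? 1 : 0;
}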
@@ -153,12 +160,33 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
       // compute new conv_weight for int8
       auto weight_scale =
           conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
-      for (unsigned int i = 0; i < h; ++i) {
-        weight_scale[i] *= fabsf(alpha_data[i]);
-        if (alpha_data[i] < 0.f) {
-          auto ptr_row = conv_weight_d + i * w;
-          for (unsigned int j = 0; j < w; ++j) {
-            ptr_row[j] *= -1;
+      if (conv_type_ == "conv2d_transpose") {
+        int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
+                     conv_weight_t->dims()[3];
+        int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
+        // fold |alpha| into each output channel's scale exactly once
+        for (unsigned int i = 0; i < h; ++i) {
+          weight_scale[i] *= fabsf(alpha_data[i]);
+        }
+        // flip the sign of every hw-sized slice of a negative-alpha channel
+        for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) {
+          for (unsigned int i = 0; i < h; ++i) {
+            if (alpha_data[i] < 0.f) {
+              auto ptr_row = conv_weight_d + k * c_size + i * hw;
+              for (unsigned int j = 0; j < hw; ++j) {
+                ptr_row[j] *= -1;
+              }
+            }
+          }
+        }
+      } else {
+        for (unsigned int i = 0; i < h; ++i) {
+          weight_scale[i] *= fabsf(alpha_data[i]);
+          if (alpha_data[i] < 0.f) {
+            auto ptr_row = conv_weight_d + i * w;
+            for (unsigned int j = 0; j < w; ++j) {
+              ptr_row[j] *= -1;
+            }
           }
         }
       }
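In the int8 path the per-output-channel weight_scale must stay positive, so the pass folds |alpha| into the scale and instead flips the sign of the stored weights when alpha is negative. The extra k loop is needed because, in the {in_c, out_c, kh, kw} transpose layout, an output channel's weights are not one contiguous row: they are in_c separate slices of kh*kw values. An illustrative addressing helper (hypothetical, not in the pass):

#include <cstddef>

// Flat offset of the (k, i) slice in a conv2d_transpose weight tensor
// laid out {in_c, out_c, kh, kw}; hw = kh * kw.
inline size_t TransposeSliceOffset(size_t k, size_t i, size_t out_c, size_t hw) {
  const size_t c_size = out_c * hw;  // stride between input channels
  return k * c_size + i * hw;        // start of output channel i in input channel k
}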
@@ -176,9 +200,23 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   } else {
     // compute new conv_weight
     auto conv_weight_d = conv_weight_t->mutable_data<float>();
-    for (unsigned int i = 0; i < h; ++i) {   // h: conv2d output channels
-      for (unsigned int j = 0; j < w; ++j) { // w: weights per output channel
-        conv_weight_d[i * w + j] *= alpha_data[i];
+    if (conv_type_ == "conv2d_transpose") {
+      int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
+                   conv_weight_t->dims()[3];
+      int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
+      for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) {
+        for (unsigned int i = 0; i < h; ++i) {
+          auto ptr_row = conv_weight_d + k * c_size + i * hw;
+          for (unsigned int j = 0; j < hw; ++j) {
+            ptr_row[j] *= alpha_data[i];
+          }
+        }
+      }
+    } else {
+      for (unsigned int i = 0; i < h; ++i) {   // h: conv2d output channels
+        for (unsigned int j = 0; j < w; ++j) { // w: weights per output channel
+          conv_weight_d[i * w + j] *= alpha_data[i];
+        }
       }
     }
   }
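Putting the pieces together, here is a small self-contained program (toy shapes and names, not the PaddleLite API) that folds a per-channel alpha into a conv2d_transpose weight tensor the same way the float branch above does:

#include <cstdio>
#include <vector>

int main() {
  // Toy conv2d_transpose weights laid out {in_c, out_c, kh, kw}.
  const int in_c = 2, out_c = 3, kh = 2, kw = 2;
  const int hw = kh * kw;
  const int c_size = out_c * hw;
  std::vector<float> w(in_c * c_size, 1.0f);
  // One BN-derived multiplier per output channel.
  const float alpha[out_c] = {0.5f, 2.0f, -1.0f};

  // Scale every (k, i) slice by its output channel's alpha.
  for (int k = 0; k < in_c; ++k) {
    for (int i = 0; i < out_c; ++i) {
      float* slice = w.data() + k * c_size + i * hw;
      for (int j = 0; j < hw; ++j) {
        slice[j] *= alpha[i];
      }
    }
  }

  // Each output channel's weights now carry its alpha: 0.5, 2.0, -1.0.
  for (int i = 0; i < out_c; ++i) {
    std::printf("out channel %d first weight: %.1f\n", i, w[i * hw]);
  }
  return 0;
}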