提交 d397aed5 编写于 作者: C chenjiaoAngel

add conv_transpose+bn fusion. test=develop

上级 baa4ff00
......@@ -26,10 +26,12 @@ namespace mir {
void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialze fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d", "conv2d_transpose"};
// start fuse using params
for (auto conv_has_bias : conv_has_bias_cases) {
for (auto conv_type : conv_type_cases) {
std::cout << "conv_has_bias:" << conv_has_bias
<< " conv_type:" << conv_type << std::endl;
VLOG(4) << "conv_has_bias:" << conv_has_bias
<< " conv_type:" << conv_type;
fusion::ConvBNFuser fuser(conv_type, conv_has_bias);
......
......@@ -153,7 +153,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
// compute new conv_weight for int8
auto weight_scale =
conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
for (unsigned int i = 0; i < h; ++i) {
/* for (unsigned int i = 0; i < h; ++i) {
weight_scale[i] *= fabsf(alpha_data[i]);
if (alpha_data[i] < 0.f) {
auto ptr_row = conv_weight_d + i * w;
......@@ -162,6 +162,33 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
}
}
}
*/
if (conv_type_ == "conv2d_transpose") {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) {
for (unsigned int i = 0; i < h; ++i) {
weight_scale[i] *= fabsf(alpha_data[i]);
if (alpha_data[i] < 0.f) {
auto ptr_row = conv_weight_d + k * c_size + i * hw;
for (unsigned int j = 0; j < hw; ++j) {
ptr_row[j] *= -1;
}
}
}
}
} else {
for (unsigned int i = 0; i < h; ++i) {
weight_scale[i] *= fabsf(alpha_data[i]);
if (alpha_data[i] < 0.f) {
auto ptr_row = conv_weight_d + i * w;
for (unsigned int j = 0; j < w; ++j) {
ptr_row[j] *= -1;
}
}
}
}
conv_op_desc->SetAttr("weight_scale", weight_scale);
} else if (is_weight_quantization) {
std::string scale_name = conv_weight_name + "_quant_scale";
......@@ -176,10 +203,29 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else {
// compute new conv_weight
auto conv_weight_d = conv_weight_t->mutable_data<float>();
for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels
/*for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels
for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels
conv_weight_d[i * w + j] *= alpha_data[i];
}
}*/
if (conv_type_ == "conv2d_transpose") {
int c_size = conv_weight_t->dims()[1] * conv_weight_t->dims()[2] *
conv_weight_t->dims()[3];
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
for (unsigned int k = 0; k < conv_weight_t->dims()[0]; ++k) {
for (unsigned int i = 0; i < h; ++i) {
auto ptr_row = conv_weight_d + k * c_size + i * hw;
for (unsigned int j = 0; j < hw; ++j) {
ptr_row[j] *= alpha_data[i];
}
}
}
} else {
for (unsigned int i = 0; i < h; ++i) { // n: conv2d output channels
for (unsigned int j = 0; j < w; ++j) { // w: conv2d input channels
conv_weight_d[i * w + j] *= alpha_data[i];
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册