Unverified  Commit f358cdb8 authored by HappyAngel, committed by GitHub

[arm] add conv+conv fusion (#3967)

* add conv+conv(1x1s1p0) fusion

* fix build and run error

* fix format. test=develop
Parent c572523f
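In short, the new lite_conv_conv_fuse_pass looks for a conv2d/depthwise_conv2d whose output feeds a second convolution with a 1x1 kernel, stride 1, padding 0, and groups == 1 (the "1x1s1p0" in the commit message), and folds the second conv into the first by merging the weights and biases offline. A minimal scalar sketch of the underlying algebra (illustrative values, not part of the patch):

    // conv0: y = k*x + z        conv1: out = a*y + b
    // fused: out = a*(k*x + z) + b = (a*k)*x + (a*z + b)
    float k = 2.f, z = 3.f, a = 4.f, b = 5.f, x = 7.f;
    float chained = a * (k * x + z) + b;      // 73.0
    float fused = (a * k) * x + (a * z + b);  // 73.0, same result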
......@@ -28,6 +28,7 @@ USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(remove_tf_redundant_ops_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_conv_conv_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
......
......@@ -152,7 +152,7 @@ static bool conv_trans_weights_numc(const dtype* din,
*/
template <typename Dtype>
void local_transpose(const Dtype* din, Dtype* dout, int m, int n) {
// n % 4 == 0 m % 4 == 0
// n % 4 == 0 && m % 4 == 0
// m x n ==> n x m data transpose
int offset_m = m << 2;
const Dtype* din_ptr = din;
......
......@@ -18,6 +18,7 @@ lite_cc_library(mir_passes
fusion/conv_activation_fuse_pass.cc
fusion/var_conv_2d_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/conv_conv_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
......
......@@ -16,6 +16,9 @@ lite_cc_library(fuse_var_conv_activation
lite_cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_conv_conv
SRCS conv_conv_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api)
......@@ -42,6 +45,7 @@ set(mir_fusers
fuse_conv_activation
fuse_var_conv_activation
fuse_conv_bn
fuse_conv_conv
fuse_quant_dequant
fuse_elementwise_add_activation
fuse_transpose_softmax_transpose
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialize fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
bool has_arm = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm = true;
break;
}
}
if (!has_arm) {
return;
}
// only support fp32 fusion
for (auto conv_has_bias0 : conv_has_bias_cases) {
for (auto conv_has_bias1 : conv_has_bias_cases) {
for (auto conv_type0 : conv_type_cases) {
for (auto conv_type1 : conv_type_cases) {
VLOG(4) << "conv_has_bias0:" << conv_has_bias0
<< " conv_type0:" << conv_type0;
VLOG(4) << "conv_has_bias1:" << conv_has_bias1
<< " conv_type1:" << conv_type1;
fusion::ConvConvFuser fuser(
conv_type0, conv_type1, conv_has_bias0, conv_has_bias1);
fuser(graph.get());
}
}
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_conv_fuse_pass, paddle::lite::mir::ConvConvFusePass)
.BindTargets({TARGET(kARM)});
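// Note: the new pass is declared to the linker via the
// USE_MIR_PASS(lite_conv_conv_fuse_pass) entry added above, and is scheduled
// right after the conv+bn/elementwise fusions in the optimizer's pass list
// (see the Optimizer change at the end of this diff).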
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvConvFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include <memory>
#include <set>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
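// Matched subgraph (the Bias inputs are optional, depending on
// conv_has_bias0_ / conv_has_bias1_):
//
//   conv_input0   conv_weight0   [conv_bias0]
//         \             |             /
//             conv0 (conv_type0_)
//                       |
//                   conv_out0   conv_weight1   [conv_bias1]
//                        \           |             /
//                     conv1 (conv_type1_, groups == 1)
//                                    |
//                                conv_out1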
void ConvConvFuser::BuildPattern() {
auto* conv_input0 = VarNode("conv_input0")
->assert_is_op_input(conv_type0_, "Input")
->AsInput();
auto* conv_weight0 = VarNode("conv_weight0")
->assert_is_op_input(conv_type0_, "Filter")
->AsInput();
auto* conv0 = OpNode("conv2d0", conv_type0_)->assert_is_op(conv_type0_);
auto* conv_out0 = VarNode("conv_out0")
->assert_is_op_output(conv_type0_, "Output")
->assert_is_op_input(conv_type1_, "Input")
->AsIntermediate();
auto* conv_weight1 = VarNode("conv_weight1")
->assert_is_op_input(conv_type1_, "Filter")
->AsIntermediate();
auto* conv1 = OpNode("conv2d1", conv_type1_)
->assert_is_op(conv_type1_)
->assert_op_attr<int>("groups", 1)
->AsIntermediate();
auto* conv_out1 = VarNode("conv_out1")
->assert_is_op_output(conv_type1_, "Output")
->AsOutput();
if (conv_has_bias0_) {
if (conv_has_bias1_) {
auto* conv_bias0 = VarNode("conv_bias0")
->assert_is_op_input(conv_type0_, "Bias")
->AsIntermediate();
auto* conv_bias1 = VarNode("conv_bias1")
->assert_is_op_input(conv_type1_, "Bias")
->AsInput();
conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0})
.LinksTo({conv_out0});
conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1})
.LinksTo({conv_out1});
} else {
auto* conv_bias0 = VarNode("conv_bias0")
->assert_is_op_input(conv_type0_, "Bias")
->AsIntermediate();
conv0->LinksFrom({conv_input0, conv_weight0, conv_bias0})
.LinksTo({conv_out0});
conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1});
}
} else {
conv0->LinksFrom({conv_input0, conv_weight0}).LinksTo({conv_out0});
if (conv_has_bias1_) {
auto* conv_bias1 = VarNode("conv_bias1")
->assert_is_op_input(conv_type1_, "Bias")
->AsInput();
conv1->LinksFrom({conv_out0, conv_weight1, conv_bias1})
.LinksTo({conv_out1});
} else {
conv1->LinksFrom({conv_out0, conv_weight1}).LinksTo({conv_out1});
}
}
}
void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto conv_instruct = matched.at("conv2d0")->stmt();
auto conv_op_desc = conv_instruct->mutable_op_info();
auto conv = conv_instruct->op();
auto* scope = conv->scope();
auto conv_op_desc1 = matched.at("conv2d1")->stmt()->mutable_op_info();
// conv0
auto weight0_t = scope->FindVar(matched.at("conv_weight0")->arg()->name)
->GetMutable<lite::Tensor>();
// conv1
auto weight1_t = scope->FindVar(matched.at("conv_weight1")->arg()->name)
->GetMutable<lite::Tensor>();
// auto groups0 = conv_op_desc->GetAttr<int>("groups");
auto groups1 = conv_op_desc1->GetAttr<int>("groups");
auto strides1 = conv_op_desc1->GetAttr<std::vector<int>>("strides");
auto paddings1 = conv_op_desc1->GetAttr<std::vector<int>>("paddings");
auto dilations1 = conv_op_desc1->GetAttr<std::vector<int>>("dilations");
bool enable0_int8 = conv_op_desc->HasAttr("enable_int8");
bool enable1_int8 = conv_op_desc1->HasAttr("enable_int8");
int kh = weight1_t->dims()[2];
int kw = weight1_t->dims()[3];
if (!(kh == 1 && kw == 1)) {
return;
}
CHECK_EQ(enable0_int8, enable1_int8)
<< "the two convs must have the same compute type";
CHECK_EQ(groups1, 1) << "conv1's groups must be 1";
CHECK_EQ(weight0_t->dims()[0], weight1_t->dims()[1])
<< "weight0_dims[0] must equal weight1_dims[1]";
for (size_t i = 0; i < strides1.size(); i++) {
CHECK_EQ(strides1[i], 1) << "strides1[" << i << "]: " << strides1[i]
<< " must be 1";
}
for (size_t i = 0; i < paddings1.size(); i++) {
CHECK_EQ(paddings1[i], 0) << "paddings1[" << i << "]: " << paddings1[i]
<< " must be 0";
}
for (size_t i = 0; i < dilations1.size(); i++) {
CHECK_EQ(dilations1[i], 1) << "dilations1[" << i << "]: " << dilations1[i]
<< " must be 1";
}
// compute new_weight and new_bias
///////////////////////////////////////////////////////////////////////////////
// Compute ConvConvFuser
// Before fusion
//
// conv0(x) = kx + z = y
// conv1(y) = ay + b
//
// After fusion:
//
// conv1(conv0(x)) = a(kx + z) + b = akx + (az + b)
//
// new_weights = ak
// new_bias = az + b
///////////////////////////////////////////////////////////////////////////////
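// Concretely, with weight0 of shape [oc0, ic, kh, kw] and weight1 of shape
// [oc1, oc0, 1, 1] (as enforced by the checks above), the fused conv uses
//   new_weights[k, c, h, w] = sum_j weight1[k, j] * weight0[j, c, h, w]
//   new_bias[k] = sum_j weight1[k, j] * bias0[j] + bias1[k]
// as implemented by ConvConvFuser::ComputeNewWeight / ComputeNewBias.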
if (enable0_int8) {
LOG(FATAL) << "it doesn't support";
return;
} else {
// compute new conv_weight
Tensor weight_tensor;
auto in_dims = weight0_t->dims();
auto weight_dims = weight1_t->dims();
const float* din = weight0_t->data<float>();
const float* weights = weight1_t->data<float>();
int oc0 = in_dims[0];
int ic = in_dims[1];
int ih = in_dims[2];
int iw = in_dims[3];
int oc1 = weight_dims[0];
weight_tensor.Resize({oc1, ic, ih, iw});
float* dout = weight_tensor.mutable_data<float>();
ComputeNewWeight(dout, din, weights, oc0, ic, ih, iw, oc1);
weight0_t->CopyDataFrom(weight_tensor);
}
// compute new conv_bias
if (conv_has_bias0_ && conv_op_desc->HasInput("Bias") &&
conv_op_desc->Input("Bias").size() > 0) {
auto bias_t0 = scope->FindVar(matched.at("conv_bias0")->arg()->name)
->GetMutable<lite::Tensor>();
if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") &&
conv_op_desc1->Input("Bias").size() > 0) {
auto bias_t1 = scope->FindVar(matched.at("conv_bias1")->arg()->name)
->GetMutable<lite::Tensor>();
Tensor bias;
bias.CopyDataFrom(*bias_t1);
auto bias_data = bias.mutable_data<float>();
ComputeNewBias(bias_data, bias_t0, weight1_t, bias_t1);
bias_t1->CopyDataFrom(bias);
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias
IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0"));
} else {
Tensor bias;
auto weight_dims = weight1_t->dims();
bias.Resize({weight_dims[0]});
auto bias_d = bias.mutable_data<float>();
ComputeNewBias(bias_d, bias_t0, weight1_t, nullptr);
bias_t0->CopyDataFrom(bias);
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias0")->arg()->name}); // conv_bias
}
} else {
if (conv_has_bias1_ && conv_op_desc1->HasInput("Bias") &&
conv_op_desc1->Input("Bias").size() > 0) {
conv_op_desc->SetInput(
"Bias", {matched.at("conv_bias1")->arg()->name}); // conv_bias
IR_NODE_LINK_TO(matched.at("conv_bias1"), matched.at("conv2d0"));
}
}
conv_op_desc->SetType(conv_type0_);
conv_op_desc->SetInput("Input", {matched.at("conv_input0")->arg()->name});
conv_op_desc->SetInput("Filter", {matched.at("conv_weight0")->arg()->name});
conv_op_desc->SetOutput("Output", {matched.at("conv_out1")->arg()->name});
auto update_conv_desc = *conv_instruct->mutable_op_info();
conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
IR_OP_VAR_LINK(matched.at("conv2d0"), matched.at("conv_out1"));
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvConvFuser : public FuseBase {
public:
explicit ConvConvFuser(const std::string& conv_type0,
const std::string& conv_type1,
const bool conv_has_bias0,
const bool conv_has_bias1)
: conv_type0_(conv_type0),
conv_type1_(conv_type1),
conv_has_bias0_(conv_has_bias0),
conv_has_bias1_(conv_has_bias1) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
void ComputeNewWeight(float* dout,
const float* din,
const float* weights,
int oc0,
int ic,
int ih,
int iw,
int oc1) {
// din: conv0 weights, dims [oc0, ic, kh, kw]
// weights: conv1 1x1 weights, dims [oc1, oc0, 1, 1]
// dout: merged weights, dims [oc1, ic, kh, kw]
int in_size = ih * iw;
int in_channel_size = ic * in_size;
// dout[k, c, h, w] = sum over j of weights[k, j] * din[j, c, h, w]
for (int k = 0; k < oc1; k++) {
const float* weights_ptr = weights + k * oc0;
float* out_ptr = dout + k * in_channel_size;
for (int c = 0; c < ic; c++) {
float* out_ptr_channel = out_ptr + c * in_size;
const float* din_ptr = din + c * in_size;
for (int i = 0; i < in_size; i++) {
float sum = 0.f;
for (int j = 0; j < oc0; j++) {
sum += din_ptr[j * in_channel_size] * weights_ptr[j];
}
*out_ptr_channel++ = sum;
}
}
}
}
void ComputeNewBias(float* dout,
Tensor* bias0_tensor,
Tensor* weight_tensor,
Tensor* bias1_tensor) {
// bias0_tensor: conv0 bias [oc0]; weight_tensor: conv1 weights [oc1, oc0, 1, 1]
// bias1_tensor: conv1 bias [oc1], may be null; dout: merged bias [oc1]
auto in_dims = bias0_tensor->dims();
auto weight_dims = weight_tensor->dims();
const float* din = bias0_tensor->data<float>();
const float* weights = weight_tensor->data<float>();
int ic = in_dims[0];
int oc = weight_dims[0];
// dout[k] = sum over j of weights[k, j] * din[j] (+ bias1[k] if present)
if (bias1_tensor) {
const float* din2 = bias1_tensor->data<float>();
for (int k = 0; k < oc; k++) {
const float* weights_ptr = weights + k * ic;
float sum = 0.f;
for (int j = 0; j < ic; j++) {
sum += din[j] * weights_ptr[j];
}
dout[k] = sum + din2[k];
}
} else {
for (int k = 0; k < oc; k++) {
const float* weights_ptr = weights + k * ic;
float sum = 0.f;
for (int j = 0; j < ic; j++) {
sum += din[j] * weights_ptr[j];
}
dout[k] = sum;
}
}
}
private:
std::string conv_type0_{"conv2d"};
std::string conv_type1_{"conv2d"};
bool conv_has_bias0_{false};
bool conv_has_bias1_{false};
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
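The merge math is easy to sanity-check outside the framework. Below is a self-contained sketch (hypothetical file and names, not part of this commit) that, for a single pixel and 1x1 kernels where a convolution degenerates into a matrix-vector product, compares the chained two-conv computation against a single conv whose weights and bias are merged the same way as ComputeNewWeight / ComputeNewBias:

    // sanity_check_conv_conv_fuse.cc -- standalone sketch, not part of the commit.
    // Claim to verify: w1*(w0*x + b0) + b1 == (w1*w0)*x + (w1*b0 + b1)
    #include <cstdio>
    #include <vector>

    int main() {
      const int ic = 3, oc0 = 2, oc1 = 2;
      std::vector<float> x = {1.f, 2.f, 3.f};    // input pixel [ic]
      std::vector<float> w0 = {0.5f, -1.f, 2.f,  // conv0 weights [oc0 x ic]
                               1.f, 0.f, -0.5f};
      std::vector<float> b0 = {0.1f, -0.2f};     // conv0 bias [oc0]
      std::vector<float> w1 = {1.f, 2.f,         // conv1 1x1 weights [oc1 x oc0]
                               -1.f, 0.5f};
      std::vector<float> b1 = {0.3f, 0.f};       // conv1 bias [oc1]

      // Chained: y = w0*x + b0, then out = w1*y + b1.
      std::vector<float> y(oc0), chained(oc1);
      for (int j = 0; j < oc0; ++j) {
        y[j] = b0[j];
        for (int c = 0; c < ic; ++c) y[j] += w0[j * ic + c] * x[c];
      }
      for (int k = 0; k < oc1; ++k) {
        chained[k] = b1[k];
        for (int j = 0; j < oc0; ++j) chained[k] += w1[k * oc0 + j] * y[j];
      }

      // Fused: merged weights w = w1*w0 [oc1 x ic], merged bias = w1*b0 + b1,
      // mirroring ComputeNewWeight / ComputeNewBias above.
      std::vector<float> w(oc1 * ic), b(oc1), fused(oc1);
      for (int k = 0; k < oc1; ++k) {
        b[k] = b1[k];
        for (int j = 0; j < oc0; ++j) b[k] += w1[k * oc0 + j] * b0[j];
        for (int c = 0; c < ic; ++c) {
          float sum = 0.f;
          for (int j = 0; j < oc0; ++j) sum += w1[k * oc0 + j] * w0[j * ic + c];
          w[k * ic + c] = sum;
        }
      }
      for (int k = 0; k < oc1; ++k) {
        fused[k] = b[k];
        for (int c = 0; c < ic; ++c) fused[k] += w[k * ic + c] * x[c];
      }

      // Both columns should print identical values (3.5 and -3.55 here).
      for (int k = 0; k < oc1; ++k)
        std::printf("k=%d chained=%f fused=%f\n", k, chained[k], fused[k]);
      return 0;
    }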
......@@ -76,6 +76,7 @@ class Optimizer {
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
......