Unverified commit 4a6499a6, authored by HappyAngel, committed by GitHub

[arm] add Conv+conv fusion requirement (#4216)

* add conv+conv fusion requirement

* fix compute error
* fix format. test=develop

* fix format. test=develop

* fix according to comments. test=develop
Parent ac6c98f4
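The "requirement" this commit adds is a compute-cost gate on the fusion: folding a k×k conv (ic0 → oc0 channels) followed by a 1x1 conv (oc0 → oc1) into one k×k conv (ic0 → oc1) only pays off when the fused conv is no more expensive than the pair. Per output pixel the pair costs about k²·ic0·oc0 + oc0·oc1 multiply-accumulates while the fused conv costs k²·ic0·oc1, which the pass simplifies (dropping the k² factor) to the test ic0·(oc1 − oc0) ≤ oc0·oc1 seen in BuildPattern below. A minimal sketch of that check; the standalone helper and its name are illustrative, not part of the patch:

```cpp
#include <cstdint>

// Sketch of the conv+conv fusion cost gate, mirroring the pass's
// simplified test ic0 * (oc1 - oc0) <= oc0 * oc1 (k*k factor dropped).
bool conv_conv_fusion_worthwhile(int64_t ic0, int64_t oc0, int64_t oc1) {
  // Separate convs, per output pixel (normalized): ic0*oc0 + oc0*oc1 MACs.
  // Fused conv, per output pixel (normalized):     ic0*oc1 MACs.
  // Fuse only when the fused cost does not exceed the separate cost:
  //   ic0*oc1 - ic0*oc0 <= oc0*oc1.
  return ic0 * (oc1 - oc0) <= oc0 * oc1;
}
```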
......@@ -30,7 +30,7 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
bool has_fp32 = false;
bool has_int8 = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM)) {
if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) {
if (place.precision == PRECISION(kFloat)) {
has_fp32 = true;
}
......@@ -38,6 +38,7 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
has_int8 = true;
}
} else {
VLOG(5) << "place.target: " << static_cast<int>(place.target);
return;
}
}
......@@ -50,12 +51,12 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto conv_has_bias1 : conv_has_bias_cases) {
for (auto conv_type0 : conv_type_cases) {
for (auto conv_type1 : {"conv2d"}) { // it must be a 1x1s1p0 conv
VLOG(4) << "conv_has_bias0:" << conv_has_bias0
VLOG(5) << "conv_has_bias0:" << conv_has_bias0
<< " conv_type0:" << conv_type0;
VLOG(4) << "conv_has_bias1:" << conv_has_bias1
VLOG(5) << "conv_has_bias1:" << conv_has_bias1
<< " conv_type1:" << conv_type1;
fusion::ConvConvFuser fuser(
conv_type0, conv_type1, conv_has_bias0, conv_has_bias1);
conv_type0, conv_type1, conv_has_bias0, conv_has_bias1, graph);
fuser(graph.get());
}
}
......
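Taken together, the new BuildPattern below only forwards a conv pair to createPattern when the second conv's kernel is 1x1, its groups are 1, the first conv's output channels equal the second's input channels, neither conv is int8, and the cost gate above holds. A compact restatement of those checks as a hypothetical free function (the struct and names are illustrative only, not the pass's code):

```cpp
#include <cstdint>

// Illustrative restatement of BuildPattern's eligibility checks.
// Weight dims follow [oc, ic/groups, kh, kw].
struct ConvDesc {
  int64_t oc, ic_per_group, kh, kw;
  int64_t groups;
  bool int8;
};

bool conv_conv_fusable(const ConvDesc& c0, const ConvDesc& c1) {
  if (c1.kh != 1 || c1.kw != 1) return false;  // second conv must be 1x1
  if (c1.groups != 1) return false;            // grouped 1x1 conv not handled
  if (c0.oc != c1.ic_per_group * c1.groups) return false;  // channels chain
  if (c0.int8 || c1.int8) return false;        // fp32 only
  const int64_t ic0 = c0.ic_per_group * c0.groups;
  return ic0 * (c1.oc - c0.oc) <= c0.oc * c1.oc;  // cost gate
}
```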
......@@ -22,7 +22,7 @@ namespace lite {
namespace mir {
namespace fusion {
void ConvConvFuser::BuildPattern() {
inline void ConvConvFuser::createPattern() {
auto* conv_input0 = VarNode("conv_input0")
->assert_is_op_input(conv_type0_, "Input")
->AsInput();
......@@ -81,6 +81,91 @@ void ConvConvFuser::BuildPattern() {
}
}
void ConvConvFuser::BuildPattern() {
for (auto& node : graph_->StmtTopologicalOrder()) {
if (node->IsStmt() &&
node->AsStmt().picked_kernel().op_type() == conv_type0_) {
auto* scope = node->stmt()->op()->scope();
auto conv_op_desc0 = node->stmt()->mutable_op_info();
// find outlinks of conv2d: in_arg_node
auto conv2d_outlinks = node->outlinks;
VLOG(5) << "conv2d_outlinks.size():" << conv2d_outlinks.size();
if (conv2d_outlinks.size() == 1) {
auto next_node_tmp = conv2d_outlinks.front();
if (next_node_tmp->IsArg() && next_node_tmp->outlinks.size() == 1) {
auto next_node = next_node_tmp->outlinks.front();
auto conv0_in = node->inlinks;
auto conv0_wei_name = conv0_in.front();
VLOG(5) << "next_node->IsStmt(): " << next_node->IsStmt();
VLOG(5) << ", next op_type:"
<< next_node->AsStmt().picked_kernel().op_type();
if (next_node->IsStmt() &&
next_node->AsStmt().picked_kernel().op_type() == conv_type1_) {
// find conv->conv pattern
auto conv1_in = next_node->inlinks;
auto conv1_wei_name = conv1_in.front();
auto weight0_name = conv0_wei_name->AsArg().name;
auto weight1_name = conv1_wei_name->AsArg().name;
VLOG(5) << "conv0_wei_name: " << weight0_name;
VLOG(5) << "conv1_wei_name: " << weight1_name;
auto conv_op_desc1 = next_node->stmt()->mutable_op_info();
auto weight0_dims =
    scope->FindVar(weight0_name)->Get<lite::Tensor>().dims();
auto weight1_dims =
    scope->FindVar(weight1_name)->Get<lite::Tensor>().dims();
auto groups0 = conv_op_desc0->GetAttr<int>("groups");
auto groups1 = conv_op_desc1->GetAttr<int>("groups");
auto strides1 = conv_op_desc1->GetAttr<std::vector<int>>("strides");
auto paddings1 =
conv_op_desc1->GetAttr<std::vector<int>>("paddings");
auto dilations1 =
conv_op_desc1->GetAttr<std::vector<int>>("dilations");
auto ch_out_0 = weight0_dims[0];
auto ch_in_0 = weight0_dims[1] * groups0;
auto ch_out_1 = weight1_dims[0];
auto ch_in_1 = weight1_dims[1] * groups1;
auto kh = weight1_dims[2];
auto kw = weight1_dims[3];
bool enable0_int8 = conv_op_desc0->HasAttr("enable_int8");
bool enable1_int8 = conv_op_desc1->HasAttr("enable_int8");
if (!(kw == 1 && kh == 1)) {
VLOG(5) << "The kernel size of the second conv must be 1x1";
continue;
}
if (groups1 != 1) {
VLOG(5) << "The groups of weight1_dim must be 1";
continue;
}
if (ch_out_0 != ch_in_1) {
VLOG(5) << "channel0_out must be equal channel1_in";
continue;
}
if (enable0_int8 || enable1_int8) {
VLOG(5) << "Both convs must be fp32; int8 conv+conv fusion is "
<< "not supported";
continue;
}
// computation requirement: ic0 x (oc1 - oc0) <= oc0 x oc1
VLOG(5) << "a: " << (ch_in_0 * (ch_out_1 - ch_out_0)) << " <= "
<< "b: " << (ch_out_0 * ch_out_1);
if (ch_in_0 * (ch_out_1 - ch_out_0) > ch_out_0 * ch_out_1) {
VLOG(5) << "it dose not meet the requirment of conv+conv fusion "
<< "computation "
<< "a: " << (ch_in_0 * (ch_out_1 - ch_out_0)) << " <= "
<< "b: " << (ch_out_0 * ch_out_1);
continue;
}
// create pattern
VLOG(5) << "matched: " << conv_type0_ << " and " << conv_type1_;
createPattern();
return;
}
}
}
}
}
}
void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto conv_instruct = matched.at("conv2d0")->stmt();
auto conv_op_desc = conv_instruct->mutable_op_info();
......@@ -95,25 +180,11 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
// conv1
auto weight1_t = scope->FindVar(matched.at("conv_weight1")->arg()->name)
->GetMutable<lite::Tensor>();
// auto groups0 = conv_op_desc->GetAttr<int>("groups");
auto groups1 = conv_op_desc1->GetAttr<int>("groups");
bool enable0_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
auto strides1 = conv_op_desc1->GetAttr<std::vector<int>>("strides");
auto paddings1 = conv_op_desc1->GetAttr<std::vector<int>>("paddings");
auto dilations1 = conv_op_desc1->GetAttr<std::vector<int>>("dilations");
bool enable0_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
bool enable1_int8 = conv_op_desc1->HasAttr("enable_int8") ? true : false;
int kw = weight1_t->dims()[2];
int kh = weight1_t->dims()[3];
if (!(kw == 1 && kh == 1)) {
LOG(FATAL) << "The kernel size of the second conv must be 1x1";
}
auto channel0_out = weight0_t->dims()[0];
auto channel1_in = weight1_t->dims()[1] * groups1;
CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same";
CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1";
CHECK_EQ(channel0_out, channel1_in) << "channel0_out == channel1_in";
for (int i = 0; i < strides1.size(); i++) {
CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i]
<< " must be 1";
......
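For reference, the arithmetic behind ComputeNewWeight (declared in the header below) follows from composing two linear maps: with conv1 being 1x1, stride 1, padding 0, groups 1, the fused kernel is w_new[j][c][h][w] = Σ_i w1[j][i] · w0[i][c][h][w], and biases fold as b_new[j] = Σ_i w1[j][i] · b0[i] + b1[j]. A self-contained sketch of the weight folding under an assumed densely packed NCHW layout (names are illustrative, not the pass's implementation):

```cpp
#include <cstdint>
#include <vector>

// Fold conv0 weights [oc0, ic0, kh, kw] through a 1x1 conv1 weight
// [oc1, oc0, 1, 1] into a single kernel [oc1, ic0, kh, kw].
// Sketch only; assumes dense, contiguous fp32 tensors.
std::vector<float> fold_weights(const std::vector<float>& w0,
                                const std::vector<float>& w1,
                                int64_t oc0, int64_t ic0,
                                int64_t kh, int64_t kw, int64_t oc1) {
  const int64_t inner = ic0 * kh * kw;  // elements per conv0 output channel
  std::vector<float> w_new(oc1 * inner, 0.f);
  for (int64_t j = 0; j < oc1; ++j) {      // fused output channel
    for (int64_t i = 0; i < oc0; ++i) {    // intermediate channel
      const float s = w1[j * oc0 + i];     // 1x1 weight w1[j][i]
      for (int64_t r = 0; r < inner; ++r) {
        w_new[j * inner + r] += s * w0[i * inner + r];
      }
    }
  }
  return w_new;
}

// Biases fold the same way: b_new[j] = sum_i w1[j][i] * b0[i] + b1[j].
```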
......@@ -30,13 +30,16 @@ class ConvConvFuser : public FuseBase {
explicit ConvConvFuser(const std::string& conv_type0,
const std::string& conv_type1,
const bool conv_has_bias0,
const bool conv_has_bias1)
const bool conv_has_bias1,
const std::unique_ptr<SSAGraph>& graph)
: conv_type0_(conv_type0),
conv_type1_(conv_type1),
conv_has_bias0_(conv_has_bias0),
conv_has_bias1_(conv_has_bias1) {}
conv_has_bias1_(conv_has_bias1),
graph_(graph) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
inline void createPattern();
private:
void ComputeNewWeight(float* dout,
......@@ -112,6 +115,7 @@ class ConvConvFuser : public FuseBase {
std::string conv_type1_{"conv2d"};
bool conv_has_bias0_{false};
bool conv_has_bias1_{false};
const std::unique_ptr<SSAGraph>& graph_;
};
} // namespace fusion
......
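With the added graph parameter, the call site in ConvConvFusePass::Apply (first hunk above) remains the intended way to drive the fuser. A minimal usage sketch mirroring that pass code, for illustration only:

```cpp
// Mirrors the call site in ConvConvFusePass::Apply; illustrative only.
fusion::ConvConvFuser fuser(
    "conv2d", "conv2d", /*conv_has_bias0=*/false, /*conv_has_bias1=*/false,
    graph);             // graph: const std::unique_ptr<SSAGraph>&
fuser(graph.get());     // runs BuildPattern + InsertNewNode on matches
```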