fc_fuse_pass.cc 3.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
16
#include <memory>
17
#include <string>
18
#include <unordered_set>
19
#include <vector>
Y
Yan Chunwei 已提交
20
#include "paddle/fluid/framework/ir/graph_helper.h"
21 22 23 24 25 26
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

27 28 29
// Fuses every `mul` + `elementwise_add` subgraph matched by patterns::FC
// (with bias) into a single `fc` op.
//
// For each match, the `mul` op, the `elementwise_add` op and the intermediate
// `mul_out` variable are removed from the graph and replaced by a new `fc` op
// node wired as:
//   Input <- x, W <- mul's weight, Bias <- elementwise_add's bias, Out -> out.
//
// @param graph  Graph to rewrite in place; must not be null.
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE(graph);
  FusePassBase::Init("fc_fuse", graph);

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()
                ->NewNode("fc_fuse/x")
                ->AsInput()
                ->assert_is_op_input("mul", "X");
  patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
  fc_pattern(x, true /*with bias*/);

  int found_fc_count = 0;
  // NOTE: the first parameter must stay named `subgraph`; it is referenced by
  // name below and (presumably) inside GET_IR_NODE_FROM_SUBGRAPH.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle FC fuse";

    // Validate that the input node was matched *before* it is dereferenced,
    // so a missing node is reported through PADDLE_ENFORCE instead of an
    // exception thrown from subgraph.at().
    PADDLE_ENFORCE(subgraph.count(x));
    Node* x_node = subgraph.at(x);

    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);

    auto* base_op_desc = mul->Op();

    // Describe the fused FC op.
    OpDesc desc;
    desc.SetInput("Input", std::vector<std::string>({x_node->Name()}));
    desc.SetInput("W", std::vector<std::string>({w->Name()}));
    desc.SetInput("Bias", std::vector<std::string>({fc_bias->Name()}));
    desc.SetOutput("Out", std::vector<std::string>({fc_out->Name()}));
    desc.SetAttr("in_num_col_dims", base_op_desc->GetAttr("x_num_col_dims"));

    // For anakin subgraph int8:
    // In anakin subgraph int8 mode, a pattern like
    // "fake_quant + mul + fake_dequant" can be detected by the
    // quant_dequant_fuse_pass. That pass adds "input_scale" and
    // "weight_scale" (extracted from the fake_quant and fake_dequant ops) to
    // the mul op and then deletes the fake_quant and fake_dequant ops from the
    // graph. If the mul op carries the scale info, forward it to the fused fc.
    if (base_op_desc->HasAttr("enable_int8")) {
      desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
      desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
      desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
    }

    desc.SetType("fc");

    auto* fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
    // Remove the replaced nodes from the same graph the fc node was created
    // in (the handler's `g`), rather than from the captured outer pointer.
    GraphSafeRemoveNodes(g, {mul, elementwise_add, mul_out});

    // Wire the new fc op between its inputs and the original output.
    IR_NODE_LINK_TO(x_node, fc_node);
    IR_NODE_LINK_TO(w, fc_node);
    IR_NODE_LINK_TO(fc_bias, fc_node);
    IR_NODE_LINK_TO(fc_node, fc_out);

    found_fc_count++;
  };

  gpd(graph, handler);

  // Pass the number of fused subgraphs to AddStatis.
  AddStatis(found_fc_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers FCFusePass under the pass name `fc_fuse_pass` (see the
// REGISTER_PASS macro definition for the exact registration mechanics).
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass);