// fc_fuse_pass.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
Y
Yan Chunwei 已提交
27
  FusePassBase::Init("fc_fuse", graph.get());
28 29 30

  std::unordered_set<Node*> nodes2delete;

31
  GraphPatternDetector gpd;
32 33 34 35 36 37 38 39 40 41 42 43
  // BuildFCPattern(gpd.mutable_pattern());
  auto* x = gpd.mutable_pattern()
                ->NewNode("fc_fuse/x")
                ->AsInput()
                ->assert_is_op_input("mul", "X");
  patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/);

#define GET_NODE(id)                                                         \
  PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
                 "pattern has no Node called %s", #id);                      \
  auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id));        \
  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
44

Y
Yan Chunwei 已提交
45
  int found_fc_count = 0;
46
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
47 48 49 50 51 52
                     Graph* g) {
    VLOG(4) << "handle FC fuse";
    // Currently, there is no FC op available, so I will just simulate the
    // scenerio.
    // FC's fusion is simple, just op fuse, no need to process the
    // parameters.
53 54 55 56 57 58 59
    GET_NODE(x);                // x
    GET_NODE(w);                // Y
    GET_NODE(fc_bias);          // bias
    GET_NODE(fc_out);           // Out
    GET_NODE(mul);              // MUL op
    GET_NODE(elementwise_add);  // ELEMENT_ADD op
    GET_NODE(mul_out);          // tmp
60 61 62 63
#undef GET_NODE

    // Create an FC Node.
    OpDesc desc;
64 65 66 67
    std::string fc_x_in = x->Name();
    std::string fc_Y_in = w->Name();
    std::string fc_bias_in = fc_bias->Name();
    std::string fc_out_out = fc_out->Name();
68 69 70
    desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
    desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
    desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
71
    desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
72 73
    desc.SetType("fc");
    auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
74
    GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
75

76 77 78 79
    IR_NODE_LINK_TO(x, fc_node);
    IR_NODE_LINK_TO(w, fc_node);
    IR_NODE_LINK_TO(fc_bias, fc_node);
    IR_NODE_LINK_TO(fc_node, fc_out);
Y
Yan Chunwei 已提交
80 81

    found_fc_count++;
82 83 84 85
  };

  gpd(graph.get(), handler);

Y
Yan Chunwei 已提交
86
  AddStatis(found_fc_count);
87 88 89 90 91 92 93 94
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers FCFusePass under the pass name "fc_fuse_pass".
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass);