// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/inplace_op_var_pass.h"

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

// Rewrites eligible reshape2 ops to run "in place": the op's output variable
// is renamed to its input variable, so no extra buffer is produced. A reshape2
// is eligible only if its X input is non-persistable and consumed solely by
// this op, and neither its input nor output name is referenced from any
// sub-block (control-flow block) of the program.
void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
  FusePassBase::Init("inplace_op_var", graph);
  int found_subgraph_count = 0;
  // First convert qualifying flatten_contiguous_range ops into reshape2 so
  // they can be handled uniformly below.
  MapToReshape(graph);

  auto nodes = graph->Nodes();
  // True iff `node` is a reshape2 whose X input var is non-persistable and
  // has this op as its only consumer (outputs.size() == 1). Persistable vars
  // (weights) and shared inputs must not be renamed.
  auto is_valid_reshape = [](Node* node) {
    // Some cases need to consider, please refer to
    // https://github.com/PaddlePaddle/Paddle/pull/49146
    if (node->IsOp() && node->Op()->Type() == "reshape2") {
      auto x_name = node->Op()->Input("X").front();
      for (auto* var_node : node->inputs) {
        if (var_node->Name() == x_name) {
          if (!var_node->Var()->Persistable() && var_node->outputs.size() == 1)
            return true;
        }
      }
    }
    return false;
  };

  // Record all reshape2 op's input name and output name in block 0.
  // If the name used in other block, we can not inplace reshape op.
  std::unordered_set<std::string> var_names, deny_var_names;
  // Collect every var name touched by an eligible reshape2 in the main block.
  for (auto* node : nodes) {
    if (is_valid_reshape(node)) {
      for (auto n : node->inputs) var_names.insert(n->Name());
      for (auto n : node->outputs) var_names.insert(n->Name());
    }
  }
  // Any of those names that also appears in a sub-graph (block index >= 1)
  // goes on the deny list: renaming it would break the other block's access.
  for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
    auto sub_graph = graph->GetSubGraph(i);
    for (auto* node : sub_graph->Nodes()) {
      if (node->IsOp()) {
        for (auto var_node : node->inputs) {
          if (var_names.count(var_node->Name()))
            deny_var_names.insert(var_node->Name());
        }
        for (auto var_node : node->outputs) {
          if (var_names.count(var_node->Name()))
            deny_var_names.insert(var_node->Name());
        }
      }
    }
  }

  // inplace all reshape op.
  // Topological order so that a rename performed here is seen consistently by
  // downstream ops processed later.
  auto topo_nodes = TopologySortOperations(*graph);
  for (auto* node : topo_nodes) {
    if (!is_valid_reshape(node)) continue;
    auto* op_node = node->Op();
    auto input_name = op_node->Input("X")[0];
    auto output_name = op_node->Output("Out")[0];
    if (deny_var_names.count(input_name) || deny_var_names.count(output_name)) {
      continue;
    }
    ++found_subgraph_count;
    // Rename the output var node to the input's name and patch every consumer
    // op's input list accordingly.
    for (auto* out_var : node->outputs) {
      if (out_var->Name() == output_name) {
        out_var->RenameVar(input_name);
        for (auto* next_op : out_var->outputs) {
          next_op->Op()->RenameInput(output_name, input_name);
          next_op->Op()->Flush();
        }
      }
    }

    // Finally update the reshape2 op itself so Out == X (the in-place form).
    op_node->RenameOutput(output_name, input_name);
    op_node->Flush();
  }
  AddStatis(found_subgraph_count);
}

// Maps flatten_contiguous_range ops to reshape2 where the two are equivalent,
// so that ApplyImpl can later inplace them like any other reshape2.
// The mapping applies only to flatten(start_axis=1, stop_axis=3) on a 4-D
// input whose last two dims are 1, i.e. [N, C, 1, 1] -> [N, C], which is
// exactly reshape2 with shape = {0, -1} (0 keeps dim 0, -1 infers the rest).
void InplaceOpVarPass::MapToReshape(ir::Graph* graph) const {
  // flatten_contiguous_range op map to reshape.
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != "flatten_contiguous_range") {
      continue;
    }
    auto* op_node = node->Op();
    auto start_axis = PADDLE_GET_CONST(int, op_node->GetAttr("start_axis"));
    auto stop_axis = PADDLE_GET_CONST(int, op_node->GetAttr("stop_axis"));
    // Only the [N, C, 1, 1] flatten pattern is convertible; check the cheap
    // attribute conditions before doing any var lookup.
    if (start_axis != 1 || stop_axis != 3) continue;
    auto input_name = op_node->Input("X")[0];
    auto* block = op_node->Block();
    auto* input_var = block->FindVar(input_name);
    // FindVar returns nullptr when the var is not declared in this block
    // (e.g. it lives in an enclosing block); skip instead of dereferencing.
    if (input_var == nullptr) continue;
    auto input_shape = input_var->GetShape();
    if (input_shape.size() == 4 && input_shape[2] == 1 && input_shape[3] == 1) {
      op_node->SetType("reshape2");
      op_node->SetAttr("shape", std::vector<int>{0, -1});
      op_node->Flush();
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Register the pass with the framework and declare the op-version it is
// compatible with: it is only applied when reshape2 is at version 0.
REGISTER_PASS(inplace_op_var_pass, paddle::framework::ir::InplaceOpVarPass);
REGISTER_PASS_CAPABILITY(inplace_op_var_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "reshape2", 0));