// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <functional>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

// Returns true if `name` appears among the output variable names of `op_node`.
bool HasOutVarName(Node* op_node, const std::string& name) {
  auto* op_desc = op_node->Op();
  const auto& outputs = op_desc->Outputs();
  for (const auto& iter : outputs) {
    const auto& out_names = iter.second;
    if (std::count(out_names.begin(), out_names.end(), name) > 0) {
      return true;
    }
  }
  return false;
}

namespace patterns {

struct VarWithRepeatedOpsPattern : public PatternBase {
  VarWithRepeatedOpsPattern(PDPattern* pattern,
                            const std::string& name_scope,
                            const std::string& op_type);

  // declare variable node's name
  PATTERN_DECL_NODE(in_var);

  std::string op_type_;
};

VarWithRepeatedOpsPattern::VarWithRepeatedOpsPattern(
    PDPattern* pattern,
    const std::string& name_scope,
    const std::string& op_type)
    : PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
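  // Match a variable that is consumed by more than one op of type `op_type_`.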
  pattern->NewNode(in_var_repr())
      ->assert_is_var()
      ->assert_more([&](Node* node) {
        auto out_nodes = node->outputs;
        if (out_nodes.size() <= 1) return false;
        int op_counts = 0;
        for (auto* next_op : out_nodes) {
          if (next_op->Name() == op_type_) {
            op_counts++;
          }
        }
        return op_counts > 1;
      });
}

}  // namespace patterns

/*
Delete repeated ops, for example:
Origin subgraph:
     (input_variable)
      /     |    \     ...
    shape shape shape  ...
      |     |     |    ...
     op0   op1   op2   ...

Optimized subgraph:
      (input_variable)
            |
          shape
         /  |  \     ...
       op0 op1 op2   ...
*/
class DeleteRepeatedOpsPass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
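  // Removes duplicated `op_type` ops that read the same input variable.
  // `gen_op_key_fn` builds the key used to decide whether two ops are equal.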
  void DeleteRepeatedOps(
      ir::Graph* graph,
      const std::string& op_type,
      std::function<std::string(OpDesc*)> gen_op_key_fn) const;

  const std::string name_scope_{"delete_repeated_ops_pass"};
};

void DeleteRepeatedOpsPass::DeleteRepeatedOps(
    ir::Graph* graph,
    const std::string& op_type,
    std::function<std::string(OpDesc*)> gen_op_key_fn) const {
  GraphPatternDetector gpd;
  patterns::VarWithRepeatedOpsPattern pattern(
      gpd.mutable_pattern(), name_scope_, op_type);

  int delete_counts = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle DeleteRepeatedOps";
    GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);

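    // Candidates whose output feeds one of these ops (or an op that also
    // writes the same variable) are skipped, since renaming their inputs is
    // not safe.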
    std::vector<std::string> invalid_out_ops{
        "while", "conditional_block", "fetch"};
    std::map<std::string, std::vector<Node*>> ops_map;
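    // Group the remaining candidate ops by attribute key; ops that share a key
    // compute the same result and are therefore redundant.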
    for (auto* next_op : in_var->outputs) {
      if (next_op->Name() != op_type) continue;
      auto* op = next_op;
      bool out_op_is_invalid = false;
      for (auto* out_op : op->outputs[0]->outputs) {
        if (std::count(invalid_out_ops.begin(),
                       invalid_out_ops.end(),
                       out_op->Name()) > 0 ||
            HasOutVarName(out_op, op->outputs[0]->Name())) {
          out_op_is_invalid = true;
          break;
        }
      }
      if (out_op_is_invalid) continue;
      auto attr_key = gen_op_key_fn(op->Op());
      ops_map[attr_key].push_back(op);
    }
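    // Drop groups with a single op; only real duplicates are processed below.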
    for (auto iter = ops_map.begin(); iter != ops_map.end();) {
      if (iter->second.size() <= 1) {
        iter = ops_map.erase(iter);
      } else {
        iter++;
      }
    }

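    // Keep the first op of each group, redirect consumers of the other ops'
    // outputs to the first op's output, then delete the redundant ops.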
    for (const auto& iter : ops_map) {
      const auto& ops = iter.second;
      auto* first_op_out = ops[0]->outputs[0];
      auto first_op_out_name = first_op_out->Name();
      std::unordered_set<const Node*> delete_nodes;
      for (size_t i = 1; i < ops.size(); i++) {
        auto* cur_op = ops[i];
        auto* cur_op_out = cur_op->outputs[0];
        auto cur_op_out_name = cur_op_out->Name();
        for (auto* out_op : cur_op_out->outputs) {
          out_op->Op()->RenameInput(cur_op_out_name, first_op_out_name);
          IR_NODE_LINK_TO(first_op_out, out_op);
        }
        delete_nodes.insert(cur_op);
        delete_nodes.insert(cur_op_out);
        delete_counts++;
      }
      GraphSafeRemoveNodes(graph, delete_nodes);
    }
  };

  gpd(graph, handler);
  if (delete_counts > 0) {
    LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
              << " ops";
  }
}

std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; }

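// Two slice ops are duplicates only if starts, ends, axes and decrease_axis
// all match.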
std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
  std::string attr_key;
  auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
  auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
  auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
  auto decrease_axis =
      slice_op_desc->GetAttrIfExists<std::vector<int>>("decrease_axis");
  attr_key += "starts_";
  for (auto start : starts) {
    attr_key += std::to_string(start) + "_";
  }
  attr_key += "ends_";
  for (auto end : ends) {
    attr_key += std::to_string(end) + "_";
  }
  attr_key += "axes_";
  for (auto axis : axes) {
    attr_key += std::to_string(axis) + "_";
  }
  attr_key += "decrease_axis_";
  for (auto axis : decrease_axis) {
    attr_key += std::to_string(axis) + "_";
  }
  return attr_key;
}

std::string GenCastAttrKey(OpDesc* cast_op_desc) {
  auto in_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
  auto out_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
  return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" +
         std::to_string(out_dtype);
}

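// elementwise_add ops are duplicates only if they share both inputs and the
// broadcast axis.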
std::string GenAddAttrKey(OpDesc* add_op_desc) {
  std::string x_name = add_op_desc->Input("X")[0];
  std::string y_name = add_op_desc->Input("Y")[0];
  auto axis = add_op_desc->GetAttrIfExists<int>("axis");
  return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
}

std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
  auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
  auto bias = scale_op_desc->GetAttrIfExists<float>("bias");
  auto bias_after_scale =
      scale_op_desc->GetAttrIfExists<bool>("bias_after_scale");
  return "scale_" + std::to_string(scale) + "_bias_" + std::to_string(bias) +
         "_bias_after_scale_" + std::to_string(bias_after_scale);
}

void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
  DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
  DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
  DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
  DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
  DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_repeated_ops_pass,
              paddle::framework::ir::DeleteRepeatedOpsPass);

REGISTER_PASS_CAPABILITY(delete_repeated_ops_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "shape", 0));