// delete_repeated_ops_pass.cc
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
35 36 37 38 39 40 41 42 43 44 45 46 47

bool HasOutVarName(Node* op_node, std::string name) {
  auto* op_desc = op_node->Op();
  auto outputs = op_desc->Outputs();
  for (auto iter : outputs) {
    auto out_names = iter.second;
    if (std::count(out_names.begin(), out_names.end(), name) > 0) {
      return true;
    }
  }
  return false;
}

48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
namespace patterns {

struct VarWithRepeatedOpsPattern : public PatternBase {
  VarWithRepeatedOpsPattern(PDPattern* pattern,
                            const std::string& name_scope,
                            const std::string& op_type);

  // declare variable node's name
  PATTERN_DECL_NODE(in_var);

  std::string op_type_;
};

VarWithRepeatedOpsPattern::VarWithRepeatedOpsPattern(
    PDPattern* pattern,
    const std::string& name_scope,
    const std::string& op_type)
    : PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
  pattern->NewNode(in_var_repr())
      ->assert_is_var()
      ->assert_more([&](Node* node) {
        auto out_nodes = node->outputs;
        if (out_nodes.size() <= 1) return false;
        int op_counts = 0;
        for (auto* next_op : out_nodes) {
          if (next_op->Name() == op_type_) {
            op_counts++;
          }
        }
        return op_counts > 1;
      });
}

}  // namespace patterns

/*
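Delete repeated ops that read the same input variable and are considered
identical (shape, slice, cast, elementwise_add, scale, gather, squeeze2 and
unsqueeze2 are handled), for example:
Original subgraph: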
     (input_variable)
      /     |    \     ...
    shape shape shape  ...
      |     |     |    ...
     op0   op1   op2   ...

Optimized subgraph:
      (input_variable)
            |
          shape
         /  |  \     ...
       op0 op1 op2   ...
*/
class DeleteRepeatedOpsPass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  // For each variable, merges its consumers of type `op_type` whose keys
  // produced by `gen_op_key_fn` are equal; the first op of each group is kept.
  void DeleteRepeatedOps(ir::Graph* graph,
                         const std::string& op_type,
                         std::function<std::string(Node*)> gen_op_key_fn) const;

  const std::string name_scope_{"delete_repeated_ops_pass"};
  // Number of ops removed in the current round of ApplyImpl. It is `mutable`
  // because ApplyImpl and DeleteRepeatedOps are const member functions.
  mutable int delete_op_count{0};
};

112 113 114
void DeleteRepeatedOpsPass::DeleteRepeatedOps(
    ir::Graph* graph,
    const std::string& op_type,
115
    std::function<std::string(Node*)> gen_op_key_fn) const {
116 117
  GraphPatternDetector gpd;
  patterns::VarWithRepeatedOpsPattern pattern(
118
      gpd.mutable_pattern(), name_scope_, op_type);
119 120 121 122

  int delete_counts = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
123
    VLOG(4) << "handle DeleteRepeatedOps";
124 125
    GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);

126 127 128
    std::vector<std::string> invalid_out_ops{
        "while", "conditional_block", "fetch"};
    std::map<std::string, std::vector<Node*>> ops_map;
129
    for (auto* next_op : in_var->outputs) {
130 131 132 133 134 135 136 137 138
      if (next_op->Name() != op_type) continue;
      auto* op = next_op;
      bool out_op_is_invalid = false;
      for (auto* out_op : op->outputs[0]->outputs) {
        if (std::count(invalid_out_ops.begin(),
                       invalid_out_ops.end(),
                       out_op->Name()) > 0 ||
            HasOutVarName(out_op, op->outputs[0]->Name())) {
          out_op_is_invalid = true;
139 140 141
          break;
        }
      }
142
      if (out_op_is_invalid) continue;
143
      auto attr_key = gen_op_key_fn(op);
144 145 146 147 148 149 150
      ops_map[attr_key].push_back(op);
    }
    for (auto iter = ops_map.begin(); iter != ops_map.end();) {
      if (iter->second.size() <= 1) {
        iter = ops_map.erase(iter);
      } else {
        iter++;
151 152 153
      }
    }

154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
    for (auto iter : ops_map) {
      auto ops = iter.second;
      auto* first_op_out = ops[0]->outputs[0];
      auto first_op_out_name = first_op_out->Name();
      std::unordered_set<const Node*> delete_nodes;
      for (size_t i = 1; i < ops.size(); i++) {
        auto* cur_op = ops[i];
        auto* cur_op_out = cur_op->outputs[0];
        auto cur_op_out_name = cur_op_out->Name();
        for (auto* out_op : cur_op_out->outputs) {
          out_op->Op()->RenameInput(cur_op_out_name, first_op_out_name);
          IR_NODE_LINK_TO(first_op_out, out_op);
        }
        delete_nodes.insert(cur_op);
        delete_nodes.insert(cur_op_out);
        delete_counts++;
170
      }
171
      GraphSafeRemoveNodes(graph, delete_nodes);
172 173 174 175
    }
  };

  gpd(graph, handler);
176
  delete_op_count += delete_counts;
177 178 179 180
  if (delete_counts > 0) {
    LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
              << " ops";
  }
181 182
}

183
std::string GenShapeAttrKey(Node* shape_op_node) { return ""; }
184

185
std::string GenSliceAttrKey(Node* slice_op_node) {
186
  std::string attr_key;
187
  auto slice_op_desc = slice_op_node->Op();
188 189 190
  auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
  auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
  auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
191 192
  auto decrease_axis =
      slice_op_desc->GetAttrIfExists<std::vector<int>>("decrease_axis");
193 194 195 196 197 198 199 200 201 202 203 204
  attr_key += "starts_";
  for (auto start : starts) {
    attr_key += std::to_string(start) + "_";
  }
  attr_key += "ends_";
  for (auto end : ends) {
    attr_key += std::to_string(end) + "_";
  }
  attr_key += "axes_";
  for (auto axis : axes) {
    attr_key += std::to_string(axis) + "_";
  }
205 206 207 208
  attr_key += "decrease_axis_";
  for (auto axis : decrease_axis) {
    attr_key += std::to_string(axis) + "_";
  }
209 210 211
  return attr_key;
}

212 213
std::string GenCastAttrKey(Node* cast_op_node) {
  auto cast_op_desc = cast_op_node->Op();
214 215 216 217 218
  auto in_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
  auto out_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
  return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" +
         std::to_string(out_dtype);
}
219

220 221
std::string GenAddAttrKey(Node* add_op_node) {
  auto add_op_desc = add_op_node->Op();
222 223 224 225 226
  std::string x_name = add_op_desc->Input("X")[0];
  std::string y_name = add_op_desc->Input("Y")[0];
  auto axis = add_op_desc->GetAttrIfExists<int>("axis");
  return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
}
227

228 229
std::string GenScaleAttrKey(Node* scale_op_node) {
  auto scale_op_desc = scale_op_node->Op();
230 231 232 233 234 235
  auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
  auto bias = scale_op_desc->GetAttrIfExists<float>("bias");
  auto bias_after_scale =
      scale_op_desc->GetAttrIfExists<bool>("bias_after_scale");
  return "scale_" + std::to_string(scale) + "_bias_" + std::to_string(bias) +
         "_bias_after_scale_" + std::to_string(bias_after_scale);
236 237
}

238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
std::string GenGatherAttrKey(Node* gather_op_node) {
  std::string input_names{""};
  for (auto input_var : gather_op_node->inputs) {
    input_names += input_var->Var()->Name();
  }
  auto gather_op_desc = gather_op_node->Op();
  auto axis = gather_op_desc->GetAttrIfExists<int>("axis");
  return "axis_" + std::to_string(axis) + "_input_names_" + input_names;
}

// Key for `squeeze2` / `unsqueeze2` ops: "axes_" followed by "<axis>_" for
// every entry of the axes attribute.
std::string GenSqueeze2AttrKey(Node* squeeze2_op_node) {
  const auto axes =
      squeeze2_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
  std::string attr_key = "axes_";
  for (size_t i = 0; i < axes.size(); ++i) {
    attr_key += std::to_string(axes[i]);
    attr_key += "_";
  }
  return attr_key;
}

259 260 261 262
void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
  int repeat_time = 0;
  int total_delete_op_count = 0;
  // This pass needs to loop run until there are no nodes in the graph that need
  // to be deleted.
  while (true) {
    delete_op_count = 0;
    DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
    DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
    DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
    DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
    DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
    DeleteRepeatedOps(graph, "gather", GenGatherAttrKey);
    DeleteRepeatedOps(graph, "squeeze2", GenSqueeze2AttrKey);
    DeleteRepeatedOps(graph, "unsqueeze2", GenSqueeze2AttrKey);
    LOG(INFO) << "Round " << repeat_time++
              << ": delete op counts: " << delete_op_count;
    total_delete_op_count += delete_op_count;
    if (delete_op_count == 0) {
      break;  // No node need to delete.
    }
  }
  LOG(INFO) << "Total delete op counts: " << total_delete_op_count;
285 286 287 288 289 290 291 292 293 294 295 296 297
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers the pass under the name "delete_repeated_ops_pass".
REGISTER_PASS(delete_repeated_ops_pass,
              paddle::framework::ir::DeleteRepeatedOpsPass);

// NOTE(review): only the `shape` op version is pinned here, although the pass
// also rewrites slice, cast, elementwise_add, scale, gather, squeeze2 and
// unsqueeze2 ops. Confirm whether their versions should be listed too.
REGISTER_PASS_CAPABILITY(delete_repeated_ops_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "shape", 0));