graph.h 13.9 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <gflags/gflags.h>
S
sneaxiy 已提交
18 19 20
#include <map>
#include <memory>
#include <string>
21
#include <unordered_set>
S
sneaxiy 已提交
22 23 24 25 26 27
#include <vector>

#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
28
#include "paddle/utils/any.h"
S
sneaxiy 已提交
29

30 31
// Declares the gflags boolean FLAGS_convert_all_blocks so this header can
// read it. The ir::Graph methods below that check it forward calls made on a
// main graph to sub_graph 0 when the flag is true.
DECLARE_bool(convert_all_blocks);

W
wanghuancoder 已提交
32 33 34 35 36 37 38
// Forward declarations of the descriptor types that this header passes
// around by pointer (see CreateOpNode / CreateVarNode).
namespace paddle {
namespace framework {
class OpDesc;
class VarDesc;
}  // namespace framework
}  // namespace paddle

S
sneaxiy 已提交
39 40
namespace paddle {
namespace framework {
Y
Yancey1989 已提交
41 42

namespace details {
Y
Yancey1989 已提交
43 44 45

// This attr is not recommended, because the graph should not depend on
// the program once it is built.
X
Xin Pan 已提交
46
constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs";
Y
Yancey1989 已提交
47 48
}  //  namespace details

S
sneaxiy 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
namespace ir {

/*
 * The graph is a Directed Acyclic Single Static Assignment Graph.
 *
 * In more detail, the following properties must hold:
 *
 *   The graph shouldn't contain cycle. Each node is a black-box to the graph
 *   so the node itself could be a loop operator.
 *
 *   Each Variable-type node has only one input (thus single static assignment).
 *
 *   The output/input of operator is variable and the output/input of variable
 *   is operator.
 *
 * The following data hazards in Program are addressed in the Graph:
 *
 *   Write-After-Read
 *     a = op1(x)
 *     x = op2(b)
 *     A control-dependency connection is created between op1 and op2 such that
 *     op1->op2, so as to ensure correct order.
 *
 *   Write-After-Write
 *     x = op1(a)
 *     x = op2(b)
 *     A control-dependency connection is created between op1 and op2 such that
 *     op1->op2, so as to ensure correct order.
 *
 * Other properties currently hold, but are not enforced yet:
 *
 *   Variable-type node (not control dep) with the same variable name share
 *   the same underlying VarDesc.
 */
class Graph {
 public:
85
  // Construct a main_graph with some sub_graphs
S
sneaxiy 已提交
86
  explicit Graph(const ProgramDesc &program);
87 88 89 90 91 92 93 94 95 96 97 98

  // Construct a main_graph with some sub_graphs, and the 1st sub_graph is
  // constructed with ops[start_op_index, end_op_index)
  Graph(const ProgramDesc &program, const int64_t start_op_index,
        const int64_t end_op_index);

  // Construct a sub_graph
  Graph(const BlockDesc &block, const Graph *main_graph);

  // Construct a sub_graph with ops[start_op_index, end_op_index)
  Graph(const BlockDesc &block, const Graph *main_graph,
        const int64_t start_op_index, const int64_t end_op_index);
S
sneaxiy 已提交
99 100 101 102 103 104 105 106 107

  // Destroys the graph: runs every registered attribute deleter (a no-op for
  // attributes registered through SetNotOwned) and then empties both
  // attribute tables.
  virtual ~Graph() {
    // Walk attr_dels_ itself instead of looking each key up with operator[].
    // operator[] default-inserts an empty std::function for a missing key,
    // and invoking an empty std::function throws std::bad_function_call,
    // which must never escape a destructor.
    for (auto &attr_del : attr_dels_) {
      if (attr_del.second) {
        attr_del.second();
      }
    }
    attrs_.clear();
    attr_dels_.clear();
  }

108 109 110 111 112 113 114 115
  // Reports whether the graph was built from a partial ProgramDesc. With
  // FLAGS_convert_all_blocks enabled, a main graph answers with the state of
  // sub_graph 0.
  bool IsConstructedByPartialProgram() const {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->IsConstructedByPartialProgram();
    }
    return is_partial_;
  }
116

S
sneaxiy 已提交
117
  bool Has(const std::string &attr_name) const {
118 119 120 121 122
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Has(attr_name);
      }
    }
S
sneaxiy 已提交
123
    return attrs_.count(attr_name) > 0;
S
sneaxiy 已提交
124 125
  }

126 127
  template <typename AttrType>
  AttrType &GetOrInit(const std::string &attr_name) {
128 129 130 131 132
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->GetOrInit<AttrType>(attr_name);
      }
    }
133 134 135 136 137 138
    if (!Has(attr_name)) {
      Set(attr_name, new AttrType);
    }
    return Get<AttrType>(attr_name);
  }

S
sneaxiy 已提交
139 140
  template <typename AttrType>
  AttrType &Get(const std::string &attr_name) const {
141 142 143 144 145
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Get<AttrType>(attr_name);
      }
    }
146 147 148 149
    PADDLE_ENFORCE_EQ(
        Has(attr_name), true,
        platform::errors::PreconditionNotMet(
            "%s attribute not registered for current graph.", attr_name));
S
sneaxiy 已提交
150
    try {
151 152
      return *paddle::any_cast<AttrType *>(attrs_.at(attr_name));
    } catch (paddle::bad_any_cast &) {
153 154 155 156
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Invalid attribute type of %s, expected: %s, received: %s.",
          attr_name, platform::demangle(typeid(AttrType *).name()),  // NOLINT
          platform::demangle(attrs_.at(attr_name).type().name())));
S
sneaxiy 已提交
157
    }
S
sneaxiy 已提交
158 159 160 161
  }

  template <typename AttrType>
  void Set(const std::string &attr_name, AttrType *attr) {
162 163 164 165 166
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Set<AttrType>(attr_name, attr);
      }
    }
167 168 169
    PADDLE_ENFORCE_EQ(
        attrs_.count(attr_name), 0,
        platform::errors::AlreadyExists(
170 171
            "The attribute %s to be set already exists in the graph.",
            attr_name));
S
sneaxiy 已提交
172 173
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = [attr, attr_name]() {
M
minqiyang 已提交
174
      VLOG(3) << "deleting " << attr_name;
S
sneaxiy 已提交
175 176 177 178 179 180
      delete attr;
    };
  }

  template <typename AttrType>
  void SetNotOwned(const std::string &attr_name, AttrType *attr) {
181 182 183 184 185
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->SetNotOwned<AttrType>(attr_name, attr);
      }
    }
186 187
    PADDLE_ENFORCE_EQ(
        attrs_.count(attr_name), 0,
188 189 190
        platform::errors::AlreadyExists("The attribute %s to be set(not owned) "
                                        "already exists in the graph.",
                                        attr_name));
S
sneaxiy 已提交
191 192 193 194
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = []() {};
  }

X
Xin Pan 已提交
195
  void Erase(const std::string &attr_name) {
196 197 198 199 200
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Erase(attr_name);
      }
    }
201 202 203
    PADDLE_ENFORCE_NE(
        attrs_.count(attr_name), 0,
        platform::errors::NotFound(
204 205
            "The attribute %s to be erased does not exist in the graph.",
            attr_name));
X
Xin Pan 已提交
206 207 208 209 210
    attr_dels_[attr_name]();
    attrs_.erase(attr_name);
    attr_dels_.erase(attr_name);
  }

211 212 213 214 215 216 217 218
  // Returns the set of node pointers currently held by the graph. With
  // FLAGS_convert_all_blocks enabled, a main graph returns the node set of
  // sub_graph 0.
  const std::unordered_set<ir::Node *> &Nodes() const {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->Nodes();
    }
    return node_set_;
  }
S
sneaxiy 已提交
219 220

  // Create a normal variable with non-null VarDesc.
221
  ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) {
222 223 224 225 226
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->CreateVarNode(var_desc);
      }
    }
227 228 229
    PADDLE_ENFORCE_NOT_NULL(
        var_desc, platform::errors::InvalidArgument(
                      "The VarDesc used to create variable node is null."));
230 231
    auto *x =
        AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
232 233
    x->SetId(num_node_created_++);
    return x;
S
sneaxiy 已提交
234 235 236 237
  }

  // Create a normal runnable operator node from `op_desc`. The node receives
  // the next sequential id and is owned by the graph. Raises InvalidArgument
  // when `op_desc` is null. With FLAGS_convert_all_blocks enabled, a main
  // graph creates the node in sub_graph 0.
  ir::Node *CreateOpNode(OpDesc *op_desc) {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->CreateOpNode(op_desc);
    }
    PADDLE_ENFORCE_NOT_NULL(
        op_desc, platform::errors::InvalidArgument(
                     "The OpDesc used to create operator node is null."));
    ir::Node *op_node = AddNode(new ir::Node(op_desc));
    op_node->SetId(num_node_created_++);
    return op_node;
  }

  // Create a control dependency var that connects 2 operations. The var
  // holds no data; for dependency analysis it behaves like any other var.
  // Its name is kControlDepVarName followed by '@' and the current node
  // counter. With FLAGS_convert_all_blocks enabled, a main graph creates the
  // node in sub_graph 0.
  ir::Node *CreateControlDepVar() {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->CreateControlDepVar();
    }
    // TODO(panyx0718): control var name should be really unique.
    const std::string name = string::Sprintf(
        "%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
        num_node_created_);
    ir::Node *dep_var =
        AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
    dep_var->SetId(num_node_created_++);
    return dep_var;
  }

  // A more free-style way of creating a graph node. Mostly used for tests
  // or to "copy" from another node. Avoid using it if possible.
  // With FLAGS_convert_all_blocks enabled, a main graph creates the node in
  // sub_graph 0.
  ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->CreateEmptyNode(name, type);
    }
    ir::Node *empty_node = AddNode(new ir::Node(name, type, block_id_));
    empty_node->SetId(num_node_created_++);
    return empty_node;
  }

  // Clear all node information of the graph and return the ownership of the
  // nodes to the caller. With FLAGS_convert_all_blocks enabled, a main graph
  // releases the nodes of sub_graph 0.
  std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->ReleaseNodes();
    }
    std::vector<std::unique_ptr<ir::Node>> ret;
    // Allocate once up front, then move the owning pointers across.
    // `emplace_back(ptr.release())` would leak the node if the vector had to
    // grow and that allocation threw, because the raw pointer has no owner
    // at that moment.
    ret.reserve(nodes_.size());
    for (auto &n : nodes_) {
      ret.push_back(std::move(n.second));
    }
    nodes_.clear();
    node_set_.clear();
    return ret;
  }

Y
Yancey1989 已提交
299
  // Detaches `node` from the graph and hands its ownership to the caller.
  // Raises PreconditionNotMet when the node is not part of the graph. With
  // FLAGS_convert_all_blocks enabled, a main graph removes the node from
  // sub_graph 0.
  std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->RemoveNode(node);
    }
    PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true,
                      platform::errors::PreconditionNotMet(
                          "The node to be removed does not exist."));
    // Take the raw pointer out of the owning map entry before erasing it.
    std::unique_ptr<ir::Node> detached(nodes_.at(node).release());
    node_set_.erase(node);
    nodes_.erase(node);
    return detached;
  }

S
sneaxiy 已提交
315
  // NOTE low performance, but simple and secure.
316
  Node *RetrieveNode(int id) {
317 318 319 320 321
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->RetrieveNode(id);
      }
    }
S
sneaxiy 已提交
322 323 324 325 326 327 328 329
    for (auto &node : nodes_) {
      if (node.second->id() == id) {
        return node.second.get();
      }
    }
    return nullptr;
  }

X
fix  
Xin Pan 已提交
330 331 332 333
  // Returns reference to the original program.
  // WARN: After a series of passes, the current graph can be quite
  // different from OriginProgram. Caller shouldn't assume much from
  // the returned OriginProgram.
334 335 336 337 338 339 340 341
  const ProgramDesc &OriginProgram() const {
    if (FLAGS_convert_all_blocks) {
      if (!IsMainGraph()) {
        return main_graph_->OriginProgram();
      }
    }
    return program_;
  }
X
fix  
Xin Pan 已提交
342

S
sneaxiy 已提交
343 344
  // Registers `node` with the graph and returns it.
  // This method takes ownership of `node`. Raises PreconditionNotMet when
  // the node is already part of the graph. With FLAGS_convert_all_blocks
  // enabled, a main graph adds the node to sub_graph 0.
  ir::Node *AddNode(ir::Node *node) {
    if (FLAGS_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->AddNode(node);
    }
    PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true,
                      platform::errors::PreconditionNotMet(
                          "The node to be added already exists."));
    std::unique_ptr<ir::Node> &owner = nodes_[node];
    owner.reset(node);
    node_set_.insert(node);
    return node;
  }

Y
Yancey1989 已提交
358 359 360
  // Declared here and defined elsewhere. `var_nodes` maps a name to a list
  // of nodes.
  // NOTE(review): presumably adds the control-dependency edges for the
  // write-after-read / write-after-write hazards between nodes that share a
  // variable name -- confirm against the definition.
  void ResolveHazard(
      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);

361 362 363 364
  // Create a new, duplicated graph and return it through a shared_ptr.
  // WARN: The method only clones the graph structure, not its attributes.
  std::shared_ptr<Graph> Clone();

365 366 367 368 369 370 371 372 373 374 375 376
  // A graph is a main graph exactly when it has no main_graph_ pointer set.
  bool IsMainGraph() const { return main_graph_ == nullptr; }

  // Returns the sub_graph stored at position `idx` (non-owning pointer).
  // Raises InvalidArgument when called on a graph that is not a main graph
  // or when `idx` is out of range.
  Graph *GetSubGraph(const size_t idx) const {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
    PADDLE_ENFORCE_LT(
        idx, sub_graphs_.size(),
        platform::errors::InvalidArgument("Invalid sub_graph index"));
    const std::unique_ptr<Graph> &sub_graph = sub_graphs_.at(idx);
    return sub_graph.get();
  }

377 378 379 380 381 382 383 384 385
  int GetBlockId() const {
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->block_id_;
      }
    }
    return block_id_;
  }

386 387 388 389 390 391 392
  // Returns how many sub_graphs this main graph holds. Raises
  // InvalidArgument when called on a graph that is not a main graph.
  size_t SubGraphsSize() const {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
    const size_t sub_graph_count = sub_graphs_.size();
    return sub_graph_count;
  }

Y
Yancey1989 已提交
393
 private:
394 395
  // TODO(levi): delete this interface after when we can convert all
  // blocks into sub_graphs.
Y
Yancey1989 已提交
396
  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
      const ProgramDesc &program, const int64_t start_op_index,
      const int64_t end_op_index);

  // Declared here and defined elsewhere; returns a map from a name to a list
  // of nodes.
  // NOTE(review): presumably builds the nodes for
  // ops[start_op_index, end_op_index) of `block` -- confirm against the
  // definition.
  std::map<std::string, std::vector<ir::Node *>> InitFromBlock(
      const BlockDesc &block, const int64_t start_op_index,
      const int64_t end_op_index);

  // Destroys every sub_graph held by this main graph. Raises InvalidArgument
  // when called on a graph that is not a main graph.
  void ReleaseSubGraphs() {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
    // sub_graphs_ holds unique_ptrs, so clearing it deletes the sub_graphs.
    sub_graphs_.clear();
  }

  void AddSubGraph(std::unique_ptr<Graph> sub_graph) {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
415 416 417
    PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_,
                      platform::errors::InvalidArgument(
                          "sub_graph idx is not equal to block_id_"));
418 419 420 421
    sub_graphs_.push_back(std::move(sub_graph));
  }

  // Declared here and defined elsewhere; returns a uniquely-owned Graph.
  // NOTE(review): presumably duplicates the sub_graph stored at index `idx`
  // -- confirm against the definition.
  std::unique_ptr<Graph> CloneSubGraph(const size_t idx);
Y
Yancey1989 已提交
422

S
sneaxiy 已提交
423 424
  // NOTE: program_ shouldn't be exposed to user.
  const ProgramDesc program_;
425 426 427 428 429
  // NOTE: main_graph_ doesn't hold any node. It's used as a container of
  // sub_graphs, and the sub_graph holds the nodes.
  const Graph *main_graph_;  // not owned.
  std::vector<std::unique_ptr<Graph>> sub_graphs_;

430
  std::map<std::string, paddle::any> attrs_;
S
sneaxiy 已提交
431 432 433
  std::map<std::string, std::function<void(void)>> attr_dels_;
  std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
  std::unordered_set<ir::Node *> node_set_;
434
  size_t num_node_created_{0};  // help to generate a unique node id.
435 436 437 438 439
  // NOTE(Aurelius84): Whether the graph is constructed from a partial
  // ProgramDesc. In the case of @to_static, the whole training program is
  // split into two parts: forward graph and backward graph, which can be
  // executed independently.
  bool is_partial_{false};
440 441
  // The block this SubGraph belongs to.
  int block_id_{0};
S
sneaxiy 已提交
442 443 444 445 446 447
};

bool IsControlDepVar(const ir::Node &var);
}  // namespace ir
}  // namespace framework
}  // namespace paddle