graph.h 14.0 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <gflags/gflags.h>
S
sneaxiy 已提交
18 19 20
#include <map>
#include <memory>
#include <string>
21
#include <unordered_set>
S
sneaxiy 已提交
22 23 24 25 26 27
#include <vector>

#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
28
#include "paddle/utils/any.h"
S
sneaxiy 已提交
29

30 31
// gflags boolean flag, declared here so this header can read
// FLAGS_convert_all_blocks. When it is true, a main Graph forwards its node
// and attribute operations to its sub-graph 0 (see class Graph below).
DECLARE_bool(convert_all_blocks);

W
wanghuancoder 已提交
32 33 34 35 36 37 38
// Forward declarations of the descriptor types that the Graph API below
// accepts by pointer (Graph::CreateOpNode / Graph::CreateVarNode).
namespace paddle {
namespace framework {
class OpDesc;
class VarDesc;
}  // namespace framework
}  // namespace paddle

S
sneaxiy 已提交
39 40
namespace paddle {
namespace framework {
Y
Yancey1989 已提交
41 42

namespace details {
Y
Yancey1989 已提交
43 44 45

// This attr is not recommended, because the graph should not depend on
// the program once it is built.
X
Xin Pan 已提交
46
// Name of a graph attribute. Judging by the identifier it refers to OpDescs
// left over from the original program -- NOTE(review): confirm at use sites.
constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs";
Y
Yancey1989 已提交
47 48
}  //  namespace details

S
sneaxiy 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
namespace ir {

/*
 * The graph is a Directed Acyclic Single Static Assignment Graph.
 *
 * In more detail, the following properties must hold:
 *
 *   The graph shouldn't contain cycle. Each node is a black-box to the graph
 *   so the node itself could be a loop operator.
 *
 *   Each Variable-type node has only one input (thus single static assignment).
 *
 *   The output/input of operator is variable and the output/input of variable
 *   is operator.
 *
 * The following data hazards in Program are addressed in the Graph:
 *
 *   Write-After-Read
 *     a = op1(x)
 *     x = op2(b)
 *     A control-dependency connection is created between op1 and op2 such that
 *     op1->op2, so as to ensure correct order.
 *
 *   Write-After-Write
 *     x = op1(a)
 *     x = op2(b)
 *     A control-dependency connection is created between op1 and op2 such that
 *     op1->op2, so as to ensure correct order.
 *
 * Other properties currently hold, but are not enforced yet:
 *
 *   Variable-type node (not control dep) with the same variable name share
 *   the same underlying VarDesc.
 */
class Graph {
 public:
85
  // Construct a main_graph with some sub_graphs
S
sneaxiy 已提交
86
  explicit Graph(const ProgramDesc &program);
87 88 89 90 91 92 93 94 95 96 97 98

  // Construct a main_graph with some sub_graphs, and the 1st sub_graph is
  // constructed with ops[start_op_index, end_op_index)
  Graph(const ProgramDesc &program, const int64_t start_op_index,
        const int64_t end_op_index);

  // Construct a sub_graph
  Graph(const BlockDesc &block, const Graph *main_graph);

  // Construct a sub_graph with ops[start_op_index, end_op_index)
  Graph(const BlockDesc &block, const Graph *main_graph,
        const int64_t start_op_index, const int64_t end_op_index);
S
sneaxiy 已提交
99 100 101 102 103 104 105 106 107

  virtual ~Graph() {
    for (auto &attr : attrs_) {
      attr_dels_[attr.first]();
    }
    attrs_.clear();
    attr_dels_.clear();
  }

108 109 110 111 112 113 114 115
  bool IsConstructedByPartialProgram() const {
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->IsConstructedByPartialProgram();
      }
    }
    return is_partial_;
  }
116

S
sneaxiy 已提交
117
  bool Has(const std::string &attr_name) const {
118 119 120 121 122
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Has(attr_name);
      }
    }
S
sneaxiy 已提交
123
    return attrs_.count(attr_name) > 0;
S
sneaxiy 已提交
124 125
  }

126 127
  template <typename AttrType>
  AttrType &GetOrInit(const std::string &attr_name) {
128 129 130 131 132
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->GetOrInit<AttrType>(attr_name);
      }
    }
133 134 135 136 137 138
    if (!Has(attr_name)) {
      Set(attr_name, new AttrType);
    }
    return Get<AttrType>(attr_name);
  }

S
sneaxiy 已提交
139 140
  template <typename AttrType>
  AttrType &Get(const std::string &attr_name) const {
141 142 143 144 145
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Get<AttrType>(attr_name);
      }
    }
146 147 148 149
    PADDLE_ENFORCE_EQ(
        Has(attr_name), true,
        platform::errors::PreconditionNotMet(
            "%s attribute not registered for current graph.", attr_name));
S
sneaxiy 已提交
150
    try {
151 152
      return *paddle::any_cast<AttrType *>(attrs_.at(attr_name));
    } catch (paddle::bad_any_cast &) {
153 154 155 156
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Invalid attribute type of %s, expected: %s, received: %s.",
          attr_name, platform::demangle(typeid(AttrType *).name()),  // NOLINT
          platform::demangle(attrs_.at(attr_name).type().name())));
S
sneaxiy 已提交
157
    }
S
sneaxiy 已提交
158 159 160 161
  }

  template <typename AttrType>
  void Set(const std::string &attr_name, AttrType *attr) {
162 163 164 165 166
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Set<AttrType>(attr_name, attr);
      }
    }
167 168 169
    PADDLE_ENFORCE_EQ(
        attrs_.count(attr_name), 0,
        platform::errors::AlreadyExists(
170 171
            "The attribute %s to be set already exists in the graph.",
            attr_name));
S
sneaxiy 已提交
172 173
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = [attr, attr_name]() {
M
minqiyang 已提交
174
      VLOG(3) << "deleting " << attr_name;
S
sneaxiy 已提交
175 176 177 178 179 180
      delete attr;
    };
  }

  template <typename AttrType>
  void SetNotOwned(const std::string &attr_name, AttrType *attr) {
181 182 183 184 185
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->SetNotOwned<AttrType>(attr_name, attr);
      }
    }
186 187
    PADDLE_ENFORCE_EQ(
        attrs_.count(attr_name), 0,
188 189 190
        platform::errors::AlreadyExists("The attribute %s to be set(not owned) "
                                        "already exists in the graph.",
                                        attr_name));
S
sneaxiy 已提交
191 192 193 194
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = []() {};
  }

X
Xin Pan 已提交
195
  void Erase(const std::string &attr_name) {
196 197 198 199 200
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Erase(attr_name);
      }
    }
201 202 203
    PADDLE_ENFORCE_NE(
        attrs_.count(attr_name), 0,
        platform::errors::NotFound(
204 205
            "The attribute %s to be erased does not exist in the graph.",
            attr_name));
X
Xin Pan 已提交
206 207 208 209 210
    attr_dels_[attr_name]();
    attrs_.erase(attr_name);
    attr_dels_.erase(attr_name);
  }

211 212 213 214 215 216 217 218
  const std::unordered_set<ir::Node *> &Nodes() const {
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->Nodes();
      }
    }
    return node_set_;
  }
S
sneaxiy 已提交
219 220

  // Create a normal variable with non-null VarDesc.
221
  ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) {
222 223 224 225 226
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->CreateVarNode(var_desc);
      }
    }
227 228 229
    PADDLE_ENFORCE_NOT_NULL(
        var_desc, platform::errors::InvalidArgument(
                      "The VarDesc used to create variable node is null."));
230 231
    auto *x =
        AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
232
    x->SetId(num_node_created_++);
233
    x->SetGraphId(block_id_);
234
    return x;
S
sneaxiy 已提交
235 236 237 238
  }

  // Create a normal runnable operator with OpDesc.
  ir::Node *CreateOpNode(OpDesc *op_desc) {
239 240 241 242 243
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->CreateOpNode(op_desc);
      }
    }
244 245 246
    PADDLE_ENFORCE_NOT_NULL(
        op_desc, platform::errors::InvalidArgument(
                     "The OpDesc used to create operator node is null."));
247 248
    auto *x = AddNode(new ir::Node(op_desc));
    x->SetId(num_node_created_++);
249
    x->SetGraphId(block_id_);
250
    return x;
S
sneaxiy 已提交
251 252 253 254 255 256
  }

  // Create a control dependency var that connects 2 operations. The
  // var doesn't hold any data. Other than that, it's no different from
  // other var, considering dependency analysis.
  ir::Node *CreateControlDepVar() {
257 258 259 260 261
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->CreateControlDepVar();
      }
    }
S
sneaxiy 已提交
262 263
    // TODO(panyx0718): control var name should be really unique.
    const std::string name = string::Sprintf(
264
        "%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
D
Dun Liang 已提交
265
        num_node_created_);
266
    auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
267
    x->SetId(num_node_created_++);
268
    x->SetGraphId(block_id_);
269
    return x;
S
sneaxiy 已提交
270 271 272 273 274
  }

  // A more free style way of creating a graph node. Mostly use for test
  // or "copy" from another node. Avoid using it if possible.
  ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
275 276 277 278 279
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->CreateEmptyNode(name, type);
      }
    }
280
    auto *x = AddNode(new ir::Node(name, type, block_id_));
281
    x->SetId(num_node_created_++);
282
    x->SetGraphId(block_id_);
283
    return x;
S
sneaxiy 已提交
284 285 286 287 288
  }

  // Clear all node information of the graph and return the ownership of the
  // nodes.
  std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
289 290 291 292 293
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->ReleaseNodes();
      }
    }
S
sneaxiy 已提交
294 295 296 297 298 299 300 301 302
    std::vector<std::unique_ptr<ir::Node>> ret;
    for (auto &n : nodes_) {
      ret.emplace_back(n.second.release());
    }
    nodes_.clear();
    node_set_.clear();
    return ret;
  }

Y
Yancey1989 已提交
303
  std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
304 305 306 307 308
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->RemoveNode(node);
      }
    }
309 310 311
    PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true,
                      platform::errors::PreconditionNotMet(
                          "The node to be removed does not exist."));
Y
Yancey1989 已提交
312 313 314 315 316 317 318
    std::unique_ptr<ir::Node> ret;
    ret.reset(nodes_.at(node).release());
    nodes_.erase(node);
    node_set_.erase(node);
    return ret;
  }

S
sneaxiy 已提交
319
  // NOTE low performance, but simple and secure.
320
  Node *RetrieveNode(int id) {
321 322 323 324 325
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->RetrieveNode(id);
      }
    }
S
sneaxiy 已提交
326 327 328 329 330 331 332 333
    for (auto &node : nodes_) {
      if (node.second->id() == id) {
        return node.second.get();
      }
    }
    return nullptr;
  }

X
fix  
Xin Pan 已提交
334 335 336 337
  // Returns reference to the original program.
  // WARN: After a series of passes, the current graph can be quite
  // different from OriginProgram. Caller shouldn't assume much from
  // the returned OriginProgram.
338 339 340 341 342 343 344 345
  const ProgramDesc &OriginProgram() const {
    if (FLAGS_convert_all_blocks) {
      if (!IsMainGraph()) {
        return main_graph_->OriginProgram();
      }
    }
    return program_;
  }
X
fix  
Xin Pan 已提交
346

S
sneaxiy 已提交
347 348
  // This method takes ownership of `node`.
  ir::Node *AddNode(ir::Node *node) {
349 350 351 352 353
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->AddNode(node);
      }
    }
354 355 356
    PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true,
                      platform::errors::PreconditionNotMet(
                          "The node to be added already exists."));
S
sneaxiy 已提交
357 358 359 360 361
    nodes_[node].reset(node);
    node_set_.insert(node);
    return node;
  }

Y
Yancey1989 已提交
362 363 364
  void ResolveHazard(
      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);

365 366 367 368
  // Create a new and duplicated graph.
  // WARN: The method only clones the graph structure, not its attributes.
  std::shared_ptr<Graph> Clone();

369 370 371 372 373 374 375 376 377 378 379 380
  bool IsMainGraph() const { return main_graph_ == nullptr; }

  Graph *GetSubGraph(const size_t idx) const {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
    PADDLE_ENFORCE_LT(
        idx, sub_graphs_.size(),
        platform::errors::InvalidArgument("Invalid sub_graph index"));
    return sub_graphs_.at(idx).get();
  }

381 382 383 384 385 386 387 388 389
  int GetBlockId() const {
    if (FLAGS_convert_all_blocks) {
      if (IsMainGraph()) {
        return GetSubGraph(0)->block_id_;
      }
    }
    return block_id_;
  }

390 391 392 393 394 395 396
  size_t SubGraphsSize() const {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
    return sub_graphs_.size();
  }

Y
Yancey1989 已提交
397
 private:
398 399
  // TODO(levi): delete this interface after when we can convert all
  // blocks into sub_graphs.
Y
Yancey1989 已提交
400
  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
      const ProgramDesc &program, const int64_t start_op_index,
      const int64_t end_op_index);

  std::map<std::string, std::vector<ir::Node *>> InitFromBlock(
      const BlockDesc &block, const int64_t start_op_index,
      const int64_t end_op_index);

  void ReleaseSubGraphs() {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
    sub_graphs_.clear();
  }

  void AddSubGraph(std::unique_ptr<Graph> sub_graph) {
    PADDLE_ENFORCE_EQ(
        this->IsMainGraph(), true,
        platform::errors::InvalidArgument("This graph is not main_graph"));
419 420 421
    PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_,
                      platform::errors::InvalidArgument(
                          "sub_graph idx is not equal to block_id_"));
422 423 424 425
    sub_graphs_.push_back(std::move(sub_graph));
  }

  std::unique_ptr<Graph> CloneSubGraph(const size_t idx);
Y
Yancey1989 已提交
426

S
sneaxiy 已提交
427 428
  // NOTE: program_ shouldn't be exposed to user.
  const ProgramDesc program_;
429 430 431 432 433
  // NOTE: main_graph_ doesn't hold any node. It's used as a container of
  // sub_graphs, and the sub_graph holds the nodes.
  const Graph *main_graph_;  // not owned.
  std::vector<std::unique_ptr<Graph>> sub_graphs_;

434
  std::map<std::string, paddle::any> attrs_;
S
sneaxiy 已提交
435 436 437
  std::map<std::string, std::function<void(void)>> attr_dels_;
  std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
  std::unordered_set<ir::Node *> node_set_;
438
  size_t num_node_created_{0};  // help to generate a unique node id.
439 440 441 442 443
  // NOTE(Aurelius84): Whether is constructed with partial ProgramDesc.
  // In case of @to_static, whole trainning program is splited into two
  // parts: forward graph and backward graph, which can be executed
  // independently.
  bool is_partial_{false};
444 445
  // The block this SubGraph belongs to.
  int block_id_{0};
S
sneaxiy 已提交
446 447 448 449 450 451
};

// Declaration only; the definition is not part of this header. Presumably
// reports whether `var` is a control-dependency variable such as the ones
// created by Graph::CreateControlDepVar() -- NOTE(review): confirm against
// the definition.
bool IsControlDepVar(const ir::Node &var);
}  // namespace ir
}  // namespace framework
}  // namespace paddle