// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace mir {

// Node in a MIR graph.
class Node {
 public:
  std::list<Node*> inlinks;
  std::list<Node*> outlinks;

  Node() = default;

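  // Each node plays exactly one role: an Argument (a variable) or a
  // Statement (an operator instance).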
  enum class Role {
    kArg = 0,
    kStmt,
    kNumRoles, /*should be last*/
    kUnk,
  };

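  // Statement role: an operator together with the kernel instances that can
  // execute it.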
  struct Stmt {
    std::string op_type;
    // The kernel instances this Statement contains.
    std::vector<std::unique_ptr<KernelBase>> valid_kernels;
    // TODO(Superjomn) make this a shared_ptr for resource safety.
    std::shared_ptr<OpLite> op;  // we hold op to run InferShape

    const OpInfo* op_info() {
      CHECK(op);
      return op->op_info();
    }

    Place place() const {
      CHECK(!valid_kernels.empty());
      return valid_kernels.front()->place();
    }

    KernelBase& picked_kernel() {
      CHECK(!valid_kernels.empty()) << "no kernel for " << op_type;
      return *valid_kernels.front();
    }

    friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
      os << "Statement " << other.op_type << " " << other.place();
      return os;
    }
  };

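  // Argument role: a named variable, optionally marked as a weight.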
  struct Arg {
    std::string name;
    const Type* type{};
    // Weight is a special kind of argument, it is marked as weight explicitly
    // so that some weight related optimization can take place.
    bool is_weight{false};
  };

  Arg& AsArg(const std::string& name) {
    auto& x = AsArg();
    x.name = name;
    return x;
  }

  Stmt& AsStmt(const std::string& op_type,
               std::vector<std::unique_ptr<KernelBase>>&& kernels,
               const std::shared_ptr<OpLite>& op) {
    auto& x = AsStmt();
    x.op_type = op_type;
    x.op = op;
    x.valid_kernels = std::move(kernels);
    return x;
  }

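  // Typed accessors; only valid once the corresponding role has been set.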
  Stmt* stmt() const {
    CHECK(IsStmt());
    return stmt_.get();
  }

  Arg* arg() const {
    CHECK(IsArg());
    return arg_.get();
  }

  // Set roles.
  Arg& AsArg() {
    if (role_ != Role::kUnk) {
      CHECK(role_ == Role::kArg);
      return *arg_;
    }
    role_ = Role::kArg;
    arg_.reset(new Arg);
    return *arg_;
  }
  Stmt& AsStmt() {
    if (role_ != Role::kUnk) {
      CHECK(role_ == Role::kStmt);
      return *stmt_;
    }
    role_ = Role::kStmt;
    stmt_.reset(new Stmt);
    return *stmt_;
  }

  friend std::ostream& operator<<(std::ostream& os, Node& other) {
    os << static_cast<int>(other.role_) << " ";
    if (!other.IsRoleSet()) {
      os << "Unk role node";
    }
    if (other.IsArg()) {
      auto& arg = other.AsArg();
      os << "Argument " << arg.name;
    }
    if (other.IsStmt()) {
      auto& stmt = other.AsStmt();
      os << "Statement " << stmt.op_type;
    }
    return os;
  }

  // Check roles.
  bool IsRoleSet() const { return role_ != Role::kUnk; }
  bool IsStmt() const { return role_ == Role::kStmt; }
  bool IsArg() const { return role_ == Role::kArg; }

 private:
  // Either stmt_ or arg_ is used, depending on the role.
  std::unique_ptr<Stmt> stmt_;
  std::unique_ptr<Arg> arg_;

  Role role_{Role::kUnk};
};
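
// Illustrative usage sketch (not part of the original header). `op` and
// `kernels` below are assumed to be provided by the caller, e.g. by whatever
// pass or graph builder constructs the nodes:
//
//   Node op_node;
//   op_node.AsStmt("fc", std::move(kernels), op);  // becomes a Statement node
//   Node var_node;
//   var_node.AsArg("fc_out");                      // becomes an Argument node
//   op_node.outlinks.push_back(&var_node);
//   var_node.inlinks.push_back(&op_node);
//   KernelBase& kernel = op_node.stmt()->picked_kernel();
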
}  // namespace mir
}  // namespace lite
}  // namespace paddle