node.h 3.9 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <memory>
#include <string>
S
superjomn 已提交
20
#include <utility>
S
superjomn 已提交
21 22
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
23
#include "paddle/fluid/lite/core/op_lite.h"
S
superjomn 已提交
24

S
Superjomn 已提交
25 26 27 28
namespace paddle {
namespace lite {
namespace mir {

S
superjomn 已提交
29
// Node in a MIR graph.
S
Superjomn 已提交
30 31
class Node {
 public:
S
superjomn 已提交
32 33 34 35 36 37
  std::list<Node*> inlinks;
  std::list<Node*> outlinks;

  Node() = default;

  enum class Role {
S
Superjomn 已提交
38 39
    kArg = 0,
    kStmt,
S
superjomn 已提交
40 41
    kNumRoles, /*should be last*/
    kUnk,
S
superjomn 已提交
42
  };
S
Superjomn 已提交
43

S
Superjomn 已提交
44
  struct Stmt {
S
superjomn 已提交
45
    std::string op_type;
S
Superjomn 已提交
46
    // The kernel instances this Statement contains.
S
superjomn 已提交
47
    std::vector<std::unique_ptr<KernelBase>> valid_kernels;
48 49
    // TODO(Superjomn) make this a shared_ptr for resource safety.
    std::shared_ptr<OpLite> op;  // we hold op to run InferShape
S
Superjomn 已提交
50

S
superjomn 已提交
51 52 53 54 55 56 57 58 59 60
    const OpInfo* op_info() {
      CHECK(op);
      return op->op_info();
    }

    Place place() const {
      CHECK(!valid_kernels.empty());
      return valid_kernels.front()->place();
    }

S
Superjomn 已提交
61
    KernelBase& picked_kernel() {
62
      CHECK(!valid_kernels.empty()) << "no kernel for " << op_type;
S
Superjomn 已提交
63 64
      return *valid_kernels.front();
    }
S
superjomn 已提交
65

S
Superjomn 已提交
66 67
    friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
      os << "Statement " << other.op_type << " " << other.place();
S
superjomn 已提交
68 69
      return os;
    }
S
superjomn 已提交
70 71
  };

S
Superjomn 已提交
72
  struct Arg {
S
superjomn 已提交
73
    std::string name;
74
    int id{0};
S
superjomn 已提交
75
    const Type* type{};
76 77 78
    // Weight is a special kind of argument, it is marked as weight explicitly
    // so that some weight related optimization can take place.
    bool is_weight{false};
S
superjomn 已提交
79 80
  };

81 82 83 84 85 86 87
  Arg& AsArg(const std::string& name, int id) {
    auto& x = AsArg();
    x.name = name;
    x.id = id;
    return x;
  }

S
Superjomn 已提交
88 89
  Arg& AsArg(const std::string& name) {
    auto& x = AsArg();
S
superjomn 已提交
90 91 92 93
    x.name = name;
    return x;
  }

S
Superjomn 已提交
94 95 96 97
  Stmt& AsStmt(const std::string& op_type,
               std::vector<std::unique_ptr<KernelBase>>&& kernels,
               const std::shared_ptr<OpLite>& op) {
    auto& x = AsStmt();
S
superjomn 已提交
98
    x.op_type = op_type;
99
    x.op = op;
100
    x.valid_kernels = std::move(kernels);
S
superjomn 已提交
101 102 103
    return x;
  }

104 105 106 107 108 109 110 111 112 113
  Stmt* stmt() const {
    CHECK(IsStmt());
    return stmt_.get();
  }

  Arg* arg() const {
    CHECK(IsArg());
    return arg_.get();
  }

S
superjomn 已提交
114
  // Set roles.
S
Superjomn 已提交
115
  Arg& AsArg() {
S
superjomn 已提交
116
    if (role_ != Role::kUnk) {
S
Superjomn 已提交
117 118
      CHECK(role_ == Role::kArg);
      return *arg_;
S
superjomn 已提交
119
    }
S
Superjomn 已提交
120 121 122
    role_ = Role::kArg;
    arg_.reset(new Arg);
    return *arg_;
S
superjomn 已提交
123
  }
S
Superjomn 已提交
124
  Stmt& AsStmt() {
S
superjomn 已提交
125
    if (role_ != Role::kUnk) {
S
Superjomn 已提交
126 127
      CHECK(role_ == Role::kStmt);
      return *stmt_;
S
superjomn 已提交
128
    }
S
Superjomn 已提交
129 130 131
    role_ = Role::kStmt;
    stmt_.reset(new Stmt);
    return *stmt_;
S
superjomn 已提交
132
  }
S
superjomn 已提交
133 134 135 136 137 138

  friend std::ostream& operator<<(std::ostream& os, Node& other) {
    os << static_cast<int>(other.role_) << " ";
    if (!other.IsRoleSet()) {
      os << "Unk role node";
    }
S
Superjomn 已提交
139 140
    if (other.IsArg()) {
      auto& arg = other.AsArg();
S
superjomn 已提交
141 142
      os << "Argument " << arg.name;
    }
S
Superjomn 已提交
143 144 145
    if (other.IsStmt()) {
      auto& arg = other.AsStmt();
      os << "Statement " << arg.op_type;
S
superjomn 已提交
146 147 148 149
    }
    return os;
  }

S
superjomn 已提交
150
  // Check roles.
S
superjomn 已提交
151
  bool IsRoleSet() const { return role_ != Role::kUnk; }
S
Superjomn 已提交
152 153
  bool IsStmt() const { return role_ == Role::kStmt; }
  bool IsArg() const { return role_ == Role::kArg; }
S
superjomn 已提交
154 155

 private:
S
Superjomn 已提交
156 157 158
  // Either stmt_ or argument_ is used.
  std::unique_ptr<Stmt> stmt_;
  std::unique_ptr<Arg> arg_;
S
superjomn 已提交
159 160 161

  Role role_{Role::kUnk};
};
S
Superjomn 已提交
162 163
}  // namespace mir
}  // namespace lite
S
superjomn 已提交
164
}  // namespace paddle