// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace mir {

S
superjomn 已提交
29
// Node in a MIR graph.
S
Superjomn 已提交
30 31
class Node {
 public:
S
superjomn 已提交
32 33 34 35 36 37
  std::list<Node*> inlinks;
  std::list<Node*> outlinks;

  Node() = default;

  enum class Role {
S
Superjomn 已提交
38 39
    kArg = 0,
    kStmt,
S
superjomn 已提交
40 41
    kNumRoles, /*should be last*/
    kUnk,
S
superjomn 已提交
42
  };
S
Superjomn 已提交
43

C
Chunwei 已提交
44
  class Stmt {
S
Superjomn 已提交
45
    // The kernel instances this Statement contains.
C
Chunwei 已提交
46
    std::vector<std::unique_ptr<KernelBase>> valid_kernels_;
47
    // TODO(Superjomn) make this a shared_ptr for resource safety.
C
Chunwei 已提交
48
    std::shared_ptr<OpLite> op_;  // we hold op to run InferShape
S
Superjomn 已提交
49

C
Chunwei 已提交
50 51 52 53 54
   public:
    // Refresh the operator and kernels with the latest OpInfo.
    void ResetOp(const cpp::OpDesc& op_desc,
                 const std::vector<Place>& valid_places,
                 lite::Scope* scope = nullptr);
S
superjomn 已提交
55

C
Chunwei 已提交
56 57 58
    std::string op_type() const { return op_info()->Type(); }
    const OpInfo* op_info() const;
    OpInfo* mutable_op_info();
S
superjomn 已提交
59

C
Chunwei 已提交
60 61
    void SetKernels(std::vector<std::unique_ptr<KernelBase>>&& kernels) {
      valid_kernels_ = std::move(kernels);
S
Superjomn 已提交
62
    }
C
Chunwei 已提交
63 64
    std::vector<std::unique_ptr<KernelBase>>& kernels() {
      return valid_kernels_;
S
superjomn 已提交
65
    }
C
Chunwei 已提交
66 67 68 69 70 71 72 73 74 75 76 77

    void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
    const std::shared_ptr<OpLite> op() const { return op_; }

    Place place() const;

    KernelBase& picked_kernel();

    friend std::ostream& operator<<(std::ostream& os, const Stmt& other);

    // Description.
    std::string desc;
S
superjomn 已提交
78 79
  };

S
Superjomn 已提交
80
  struct Arg {
S
superjomn 已提交
81
    std::string name;
82
    int id{0};
S
superjomn 已提交
83
    const Type* type{};
84 85 86
    // Weight is a special kind of argument, it is marked as weight explicitly
    // so that some weight related optimization can take place.
    bool is_weight{false};
S
superjomn 已提交
87 88
  };

C
Chunwei 已提交
89
  Arg& AsArg(const std::string& name, int id);
90

C
Chunwei 已提交
91
  Arg& AsArg(const std::string& name);
S
superjomn 已提交
92

S
Superjomn 已提交
93 94 95 96
  Stmt& AsStmt(const std::string& op_type,
               std::vector<std::unique_ptr<KernelBase>>&& kernels,
               const std::shared_ptr<OpLite>& op) {
    auto& x = AsStmt();
C
Chunwei 已提交
97 98
    x.SetOp(op);
    x.SetKernels(std::move(kernels));
S
superjomn 已提交
99 100 101
    return x;
  }

102 103 104 105 106 107 108 109 110 111
  Stmt* stmt() const {
    CHECK(IsStmt());
    return stmt_.get();
  }

  Arg* arg() const {
    CHECK(IsArg());
    return arg_.get();
  }

S
superjomn 已提交
112
  // Set roles.
S
Superjomn 已提交
113
  Arg& AsArg() {
S
superjomn 已提交
114
    if (role_ != Role::kUnk) {
S
Superjomn 已提交
115 116
      CHECK(role_ == Role::kArg);
      return *arg_;
S
superjomn 已提交
117
    }
S
Superjomn 已提交
118 119 120
    role_ = Role::kArg;
    arg_.reset(new Arg);
    return *arg_;
S
superjomn 已提交
121
  }
S
Superjomn 已提交
122
  Stmt& AsStmt() {
S
superjomn 已提交
123
    if (role_ != Role::kUnk) {
S
Superjomn 已提交
124 125
      CHECK(role_ == Role::kStmt);
      return *stmt_;
S
superjomn 已提交
126
    }
S
Superjomn 已提交
127 128 129
    role_ = Role::kStmt;
    stmt_.reset(new Stmt);
    return *stmt_;
S
superjomn 已提交
130
  }
S
superjomn 已提交
131 132 133 134 135 136

  friend std::ostream& operator<<(std::ostream& os, Node& other) {
    os << static_cast<int>(other.role_) << " ";
    if (!other.IsRoleSet()) {
      os << "Unk role node";
    }
S
Superjomn 已提交
137 138
    if (other.IsArg()) {
      auto& arg = other.AsArg();
S
superjomn 已提交
139 140
      os << "Argument " << arg.name;
    }
S
Superjomn 已提交
141 142
    if (other.IsStmt()) {
      auto& arg = other.AsStmt();
C
Chunwei 已提交
143
      os << "Statement " << arg.op_type();
S
superjomn 已提交
144 145 146 147
    }
    return os;
  }

S
superjomn 已提交
148
  // Check roles.
S
superjomn 已提交
149
  bool IsRoleSet() const { return role_ != Role::kUnk; }
S
Superjomn 已提交
150 151
  bool IsStmt() const { return role_ == Role::kStmt; }
  bool IsArg() const { return role_ == Role::kArg; }
S
superjomn 已提交
152 153

 private:
S
Superjomn 已提交
154 155 156
  // Either stmt_ or argument_ is used.
  std::unique_ptr<Stmt> stmt_;
  std::unique_ptr<Arg> arg_;
S
superjomn 已提交
157 158 159

  Role role_{Role::kUnk};
};
S
Superjomn 已提交
160 161
}  // namespace mir
}  // namespace lite
S
superjomn 已提交
162
}  // namespace paddle