// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace mir {

S
superjomn 已提交
28
// Node in a MIR graph.
S
Superjomn 已提交
29 30
class Node {
 public:
S
superjomn 已提交
31 32 33 34 35 36
  std::list<Node*> inlinks;
  std::list<Node*> outlinks;

  Node() = default;

  enum class Role {
S
Superjomn 已提交
37 38
    kArg = 0,
    kStmt,
S
superjomn 已提交
39 40
    kNumRoles, /*should be last*/
    kUnk,
S
superjomn 已提交
41
  };
S
Superjomn 已提交
42

S
Superjomn 已提交
43
  struct Stmt {
S
superjomn 已提交
44
    std::string op_type;
S
Superjomn 已提交
45
    // The kernel instances this Statement contains.
S
superjomn 已提交
46
    std::vector<std::unique_ptr<KernelBase>> valid_kernels;
47 48
    // TODO(Superjomn) make this a shared_ptr for resource safety.
    std::shared_ptr<OpLite> op;  // we hold op to run InferShape
S
Superjomn 已提交
49

S
superjomn 已提交
50 51 52 53 54 55 56 57 58 59
    const OpInfo* op_info() {
      CHECK(op);
      return op->op_info();
    }

    Place place() const {
      CHECK(!valid_kernels.empty());
      return valid_kernels.front()->place();
    }

S
Superjomn 已提交
60
    KernelBase& picked_kernel() {
61
      CHECK(!valid_kernels.empty()) << "no kernel for " << op_type;
S
Superjomn 已提交
62 63
      return *valid_kernels.front();
    }
S
superjomn 已提交
64

S
Superjomn 已提交
65 66
    friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
      os << "Statement " << other.op_type << " " << other.place();
S
superjomn 已提交
67 68
      return os;
    }
S
superjomn 已提交
69 70
  };

S
Superjomn 已提交
71
  struct Arg {
S
superjomn 已提交
72
    std::string name;
S
superjomn 已提交
73
    const Type* type{};
74 75 76
    // Weight is a special kind of argument, it is marked as weight explicitly
    // so that some weight related optimization can take place.
    bool is_weight{false};
S
superjomn 已提交
77 78
  };

S
Superjomn 已提交
79 80
  Arg& AsArg(const std::string& name) {
    auto& x = AsArg();
S
superjomn 已提交
81 82 83 84
    x.name = name;
    return x;
  }

S
Superjomn 已提交
85 86 87 88
  Stmt& AsStmt(const std::string& op_type,
               std::vector<std::unique_ptr<KernelBase>>&& kernels,
               const std::shared_ptr<OpLite>& op) {
    auto& x = AsStmt();
S
superjomn 已提交
89
    x.op_type = op_type;
90
    x.op = op;
91
    x.valid_kernels = std::move(kernels);
S
superjomn 已提交
92 93 94
    return x;
  }

S
superjomn 已提交
95
  // Set roles.
S
Superjomn 已提交
96
  Arg& AsArg() {
S
superjomn 已提交
97
    if (role_ != Role::kUnk) {
S
Superjomn 已提交
98 99
      CHECK(role_ == Role::kArg);
      return *arg_;
S
superjomn 已提交
100
    }
S
Superjomn 已提交
101 102 103
    role_ = Role::kArg;
    arg_.reset(new Arg);
    return *arg_;
S
superjomn 已提交
104
  }
S
Superjomn 已提交
105
  Stmt& AsStmt() {
S
superjomn 已提交
106
    if (role_ != Role::kUnk) {
S
Superjomn 已提交
107 108
      CHECK(role_ == Role::kStmt);
      return *stmt_;
S
superjomn 已提交
109
    }
S
Superjomn 已提交
110 111 112
    role_ = Role::kStmt;
    stmt_.reset(new Stmt);
    return *stmt_;
S
superjomn 已提交
113
  }
S
superjomn 已提交
114 115 116 117 118 119

  friend std::ostream& operator<<(std::ostream& os, Node& other) {
    os << static_cast<int>(other.role_) << " ";
    if (!other.IsRoleSet()) {
      os << "Unk role node";
    }
S
Superjomn 已提交
120 121
    if (other.IsArg()) {
      auto& arg = other.AsArg();
S
superjomn 已提交
122 123
      os << "Argument " << arg.name;
    }
S
Superjomn 已提交
124 125 126
    if (other.IsStmt()) {
      auto& arg = other.AsStmt();
      os << "Statement " << arg.op_type;
S
superjomn 已提交
127 128 129 130
    }
    return os;
  }

S
superjomn 已提交
131
  // Check roles.
S
superjomn 已提交
132
  bool IsRoleSet() const { return role_ != Role::kUnk; }
S
Superjomn 已提交
133 134
  bool IsStmt() const { return role_ == Role::kStmt; }
  bool IsArg() const { return role_ == Role::kArg; }
S
superjomn 已提交
135 136

 private:
S
Superjomn 已提交
137 138 139
  // Either stmt_ or argument_ is used.
  std::unique_ptr<Stmt> stmt_;
  std::unique_ptr<Arg> arg_;
S
superjomn 已提交
140 141 142

  Role role_{Role::kUnk};
};
S
Superjomn 已提交
143 144
}  // namespace mir
}  // namespace lite
S
superjomn 已提交
145
}  // namespace paddle