// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace mir {

// Node in a MIR graph. A node plays exactly one role: an Argument (a variable
// in the graph) or an Instruct (an operator plus the kernels that can run it).
class Node {
 public:
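  // In/out edges of this node (non-owning raw pointers).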
  std::list<Node*> inlinks;
  std::list<Node*> outlinks;

  Node() = default;

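  // The role determines which payload (argument_ or instruct_) is active.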
  enum class Role {
    kArgument = 0,
    kInstruct,
    kNumRoles,  // Keep right after the concrete roles; it counts them.
    kUnk,       // Role not assigned yet (not counted by kNumRoles).
  };

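  // Payload for an operator node: the op and the kernels that can execute it.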
  struct Instruct {
    std::string op_type;
    // The candidate kernel instances that can execute this Instruct.
    std::vector<std::unique_ptr<KernelBase>> valid_kernels;
    // We hold the op so that InferShape can be run for this node.
    std::shared_ptr<OpLite> op;

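    // Returns the OpInfo of the held op; the op must already be set.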
    const OpInfo* op_info() {
      CHECK(op);
      return op->op_info();
    }

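    // The Place of the first candidate kernel.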
    Place place() const {
      CHECK(!valid_kernels.empty());
      return valid_kernels.front()->place();
    }

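    // The kernel chosen to execute this Instruct; currently simply the first
    // candidate.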
    KernelBase& picked_kernel() {
      CHECK(!valid_kernels.empty());
      return *valid_kernels.front();
    }

    friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
      os << "Instruct " << other.op_type << " " << other.place();
      return os;
    }
  };

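  // Payload for a variable node in the graph.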
  struct Argument {
    std::string name;
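    // Type of the argument; nullptr until it is set.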
    const Type* type{};
    // A weight is a special kind of argument; it is marked explicitly so that
    // weight-related optimizations can take place.
    bool is_weight{false};
  };

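  // Mark this node as an Argument with the given name.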
  Argument& AsArgument(const std::string& name) {
    auto& x = AsArgument();
    x.name = name;
    return x;
  }

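  // Mark this node as an Instruct, taking ownership of the candidate kernels.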
  Instruct& AsInstruct(const std::string& op_type,
                       std::vector<std::unique_ptr<KernelBase>>&& kernels,
                       const std::shared_ptr<OpLite>& op) {
    auto& x = AsInstruct();
    x.op_type = op_type;
    x.op = op;
    x.valid_kernels = std::move(kernels);
    return x;
  }

  // Set the role. A node's role can be set only once; requesting a different
  // role afterwards triggers a CHECK failure.
  Argument& AsArgument() {
    if (role_ != Role::kUnk) {
      CHECK(role_ == Role::kArgument);
      return *argument_;
    }
    role_ = Role::kArgument;
    argument_.reset(new Argument);
    return *argument_;
  }
  Instruct& AsInstruct() {
    if (role_ != Role::kUnk) {
      CHECK(role_ == Role::kInstruct);
      return *instruct_;
    }
    role_ = Role::kInstruct;
    instruct_.reset(new Instruct);
    return *instruct_;
  }

  friend std::ostream& operator<<(std::ostream& os, Node& other) {
    os << static_cast<int>(other.role_) << " ";
    if (!other.IsRoleSet()) {
      os << "Unk role node";
    }
    if (other.IsArgument()) {
      auto& arg = other.AsArgument();
      os << "Argument " << arg.name;
    }
    if (other.IsInstruct()) {
      auto& arg = other.AsInstruct();
      os << "Instruct " << arg.op_type;
    }
    return os;
  }

  // Check roles.
  bool IsRoleSet() const { return role_ != Role::kUnk; }
  bool IsInstruct() const { return role_ == Role::kInstruct; }
  bool IsArgument() const { return role_ == Role::kArgument; }

 private:
  // Either instruct_ or argument_ is used.
  std::unique_ptr<Instruct> instruct_;
  std::unique_ptr<Argument> argument_;

  Role role_{Role::kUnk};
};
}  // namespace mir
}  // namespace lite
}  // namespace paddle