// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>
#include <memory>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"

namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {

// Type of graph nodes
// Describes a graph node: its precision, data layout, and whether it is
// persistable (a constant/weight node whose name binding must not be rebound).
class Type {
 public:
  Type(PrecisionType precision = PRECISION(kFloat),
       DataLayoutType layout = DATALAYOUT(kNCHW),
       bool persistable = false)
      : precision_(precision), layout_(layout), persistable_(persistable) {}

  void set_precision(PrecisionType precision) { precision_ = precision; }
  void set_layout(DataLayoutType layout) { layout_ = layout; }
  // Returns the value just stored. (The original declared a bool return type
  // but had no return statement, which is undefined behavior if the result
  // is ever read; keep the bool signature for caller compatibility.)
  bool set_persistable(bool persistable) { return persistable_ = persistable; }

  PrecisionType precision() const { return precision_; }
  DataLayoutType layout() const { return layout_; }
  bool persistable() const { return persistable_; }

 private:
  PrecisionType precision_{PRECISION(kFloat)};
  DataLayoutType layout_{DATALAYOUT(kNCHW)};
  bool persistable_{false};
};

// Graph to collect all of converted HiAI IR nodes
class Graph {
 public:
  template <typename T>
57 58 59
  std::shared_ptr<T> AddNode(const std::string& name,
                             PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
60 61 62 63 64 65 66 67 68 69
    auto unique_name = [&](const std::string& key) {
      int idx = 1;
      auto it = counts_.find(key);
      if (it == counts_.end()) {
        counts_.insert(std::make_pair(key, idx));
      } else {
        idx = ++(it->second);
      }
      return key + "_" + std::to_string(idx);
    };
70
    bool persistable = typeid(T) == typeid(ge::op::Const);
71 72
    auto it = nodes_.find(name);
    if (it != nodes_.end()) {
73 74 75
      // Only variable can rebind the name
      CHECK(!it->second.second.persistable() && !persistable)
          << "[NPU] Node " << name << " redefined.";
76 77 78 79 80 81 82
      // Generate a new unique name as the key to bind the origin node:
      // new_name->node
      nodes_.insert(std::make_pair(unique_name(name + "_var"), it->second));
      nodes_.erase(it);
    }
    // Create a new node and bind with the name: name->new_node
    auto node = std::make_shared<T>(unique_name(name + "_op"));
83 84
    nodes_.insert(std::make_pair(
        name, std::make_pair(node, Type(precision, layout, persistable))));
85 86 87 88 89 90 91
    return node;
  }

  // Const node
  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      const Tensor& tensor,
92 93 94 95
      PrecisionType precision = PRECISION(kFloat),
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
    return AddNode(name, tensor, tensor.dims().Vectorize(), precision, layout);
  }
96 97 98 99 100

  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      const Tensor& tensor,
      std::vector<int64_t> shape,
101 102 103 104 105 106 107 108 109 110 111
      PrecisionType precision = PRECISION(kFloat),
      DataLayoutType layout = DATALAYOUT(kNCHW));

  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      const Tensor& tensor,
      DDim dims,
      PrecisionType precision = PRECISION(kFloat),
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
    return AddNode(name, tensor, dims.Vectorize(), precision, layout);
  }
112 113 114 115 116 117

  template <typename T>
  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      const std::vector<T>& data,
      std::vector<int64_t> shape = {},
118
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
119
    const std::type_info& info = typeid(T);
120
    PrecisionType precision = PRECISION(kFloat);
121
    if (info == typeid(float)) {
122
      precision = PRECISION(kFloat);
123
    } else if (info == typeid(int8_t)) {
124
      precision = PRECISION(kFloat);
125
    } else if (info == typeid(int32_t)) {
126
      precision = PRECISION(kInt32);
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
    } else {
      LOG(FATAL) << "[NPU] Unknow data type " << info.name();
    }
    if (shape.empty()) {
      shape = {static_cast<int64_t>(data.size())};
    } else {
      int size = 1;
      for (auto i : shape) {
        size *= i;
      }
      CHECK_EQ(data.size(), size);
    }
    Tensor tensor;
    tensor.Resize(shape);
    std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
                reinterpret_cast<const uint8_t*>(data.data()),
                data.size() * sizeof(T));
144 145 146 147 148 149 150 151 152 153
    return AddNode(name, tensor, precision, layout);
  }

  template <typename T>
  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      const std::vector<T>& data,
      DDim dims,
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
    return AddNode(name, data, dims.Vectorize(), layout);
154 155 156 157 158 159 160
  }

  template <typename T>
  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      T value,
      std::vector<int64_t> shape = {1},
161
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
162 163 164 165 166
    int64_t size = 1;
    for (auto i : shape) {
      size *= i;
    }
    std::vector<T> data(size, value);
167 168 169 170 171 172 173 174 175 176
    return AddNode(name, data, shape, layout);
  }

  template <typename T>
  std::shared_ptr<ge::op::Const> AddNode(
      const std::string& name,
      T value,
      DDim dims,
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
    return AddNode(name, value, dims.Vectorize(), layout);
177 178 179 180 181 182
  }

  // Data node
  std::shared_ptr<ge::op::Data> AddNode(
      const std::string& name,
      std::vector<int64_t> shape,
183 184 185 186 187 188 189 190 191 192
      PrecisionType precision = PRECISION(kFloat),
      DataLayoutType layout = DATALAYOUT(kNCHW));

  std::shared_ptr<ge::op::Data> AddNode(
      const std::string& name,
      DDim dims,
      PrecisionType precision = PRECISION(kFloat),
      DataLayoutType layout = DATALAYOUT(kNCHW)) {
    return AddNode(name, dims.Vectorize(), precision, layout);
  }
193 194 195

  std::shared_ptr<ge::Operator> GetNode(std::string name) {
    CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
196 197 198 199 200 201
    return nodes_.at(name).first;
  }

  const Type& GetType(const std::string& name) {
    CHECK(HasNode(name)) << "[NPU] Node " << name << " not found.";
    return nodes_.at(name).second;
202 203 204 205 206 207 208
  }

  bool HasNode(const std::string& name) {
    return nodes_.find(name) != nodes_.end();
  }

 private:
209 210 211
  std::unordered_map<std::string,
                     std::pair<std::shared_ptr<ge::Operator>, Type>>
      nodes_;
212 213 214 215 216 217 218
  std::unordered_map<std::string, int> counts_;
};

}  // namespace npu
}  // namespace subgraph
}  // namespace lite
}  // namespace paddle