/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_INCLUDE_INFER_TENSOR_H_
#define MINDSPORE_INCLUDE_INFER_TENSOR_H_

#include <utility>
#include <vector>
#include <memory>
#include <numeric>
#include <map>
#include <functional>

#include "securec/include/securec.h"
#include "include/infer_log.h"

namespace mindspore {
#define MS_API __attribute__((visibility("default")))
namespace inference {

// Element data types a tensor can carry.
// kMSI_Unknown (0) marks an uninitialized or unrecognized type; all other
// entries map to a fixed byte width via InferTensorBase::GetTypeSize()
// (kMSI_Float16 is stored in a uint16_t).
enum DataType {
  kMSI_Unknown = 0,
  kMSI_Bool = 1,
  kMSI_Int8 = 2,
  kMSI_Int16 = 3,
  kMSI_Int32 = 4,
  kMSI_Int64 = 5,
  kMSI_Uint8 = 6,
  kMSI_Uint16 = 7,
  kMSI_Uint32 = 8,
  kMSI_Uint64 = 9,
  kMSI_Float16 = 10,
  kMSI_Float32 = 11,
  kMSI_Float64 = 12,
};

class InferTensorBase {
 public:
  InferTensorBase() = default;
  virtual ~InferTensorBase() = default;

  virtual DataType data_type() const = 0;
  virtual void set_data_type(DataType type) = 0;
  virtual std::vector<int64_t> shape() const = 0;
  virtual void set_shape(const std::vector<int64_t> &shape) = 0;
  virtual const void *data() const = 0;
  virtual size_t data_size() const = 0;
  virtual bool resize_data(size_t data_len) = 0;
  virtual void *mutable_data() = 0;

  bool set_data(const void *data, size_t data_len) {
    resize_data(data_len);
    if (mutable_data() == nullptr) {
      MSI_LOG_ERROR << "set data failed, data len " << data_len;
      return false;
    }
    if (data_size() != data_len) {
      MSI_LOG_ERROR << "set data failed, tensor current data size " << data_size() << " not match data len "
                    << data_len;
      return false;
    }
    if (data_len == 0) {
      return true;
    }
    memcpy_s(mutable_data(), data_size(), data, data_len);
    return true;
  }

  int64_t ElementNum() const {
    std::vector<int64_t> shapex = shape();
    return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
  }

  int GetTypeSize(DataType type) const {
    const std::map<DataType, size_t> type_size_map{
      {kMSI_Bool, sizeof(bool)},       {kMSI_Float64, sizeof(double)},   {kMSI_Int8, sizeof(int8_t)},
      {kMSI_Uint8, sizeof(uint8_t)},   {kMSI_Int16, sizeof(int16_t)},    {kMSI_Uint16, sizeof(uint16_t)},
      {kMSI_Int32, sizeof(int32_t)},   {kMSI_Uint32, sizeof(uint32_t)},  {kMSI_Int64, sizeof(int64_t)},
      {kMSI_Uint64, sizeof(uint64_t)}, {kMSI_Float16, sizeof(uint16_t)}, {kMSI_Float32, sizeof(float)},
    };
    auto it = type_size_map.find(type);
    if (it != type_size_map.end()) {
      return it->second;
    }
    return 0;
  }
};

class InferTensor : public InferTensorBase {
 public:
  DataType type_;
  std::vector<int64_t> shape_;
  std::vector<uint8_t> data_;

 public:
  InferTensor() = default;
X
xuyongfei 已提交
110
  ~InferTensor() = default;
111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
  InferTensor(DataType type, std::vector<int64_t> shape, const void *data, size_t data_len) {
    set_data_type(type);
    set_shape(shape);
    set_data(data, data_len);
  }

  void set_data_type(DataType type) override { type_ = type; }
  DataType data_type() const override { return type_; }

  void set_shape(const std::vector<int64_t> &shape) override { shape_ = shape; }
  std::vector<int64_t> shape() const override { return shape_; }

  const void *data() const override { return data_.data(); }
  size_t data_size() const override { return data_.size(); }

  bool resize_data(size_t data_len) override {
    data_.resize(data_len);
    return true;
  }
  void *mutable_data() override { return data_.data(); }
};

// Interface over a batch of encoded images destined for one model input.
class InferImagesBase {
 public:
  InferImagesBase() = default;
  virtual ~InferImagesBase() = default;

  // Number of images in the batch.
  virtual size_t batch_size() const = 0;
  // Fetches the encoded bytes of image `index` into pic_buffer/pic_size.
  // Returns false on failure (e.g. index out of range).
  virtual bool get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const = 0;
  virtual size_t input_index() const = 0;  // the index of images as input in model
};

142 143
class RequestBase {
 public:
X
xuyongfei 已提交
144 145
  RequestBase() = default;
  virtual ~RequestBase() = default;
146 147 148 149
  virtual size_t size() const = 0;
  virtual const InferTensorBase *operator[](size_t index) const = 0;
};

150 151
class ImagesRequestBase {
 public:
X
xuyongfei 已提交
152 153
  ImagesRequestBase() = default;
  virtual ~ImagesRequestBase() = default;
154 155 156 157
  virtual size_t size() const = 0;
  virtual const InferImagesBase *operator[](size_t index) const = 0;
};

158 159
class ReplyBase {
 public:
X
xuyongfei 已提交
160 161
  ReplyBase() = default;
  virtual ~ReplyBase() = default;
162 163 164 165 166 167 168 169 170 171
  virtual size_t size() const = 0;
  virtual InferTensorBase *operator[](size_t index) = 0;
  virtual const InferTensorBase *operator[](size_t index) const = 0;
  virtual InferTensorBase *add() = 0;
  virtual void clear() = 0;
};

class VectorInferTensorWrapReply : public ReplyBase {
 public:
  explicit VectorInferTensorWrapReply(std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
X
xuyongfei 已提交
172
  ~VectorInferTensorWrapReply() = default;
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199

  size_t size() const { return tensor_list_.size(); }
  InferTensorBase *operator[](size_t index) {
    if (index >= tensor_list_.size()) {
      MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
      return nullptr;
    }
    return &(tensor_list_[index]);
  }
  const InferTensorBase *operator[](size_t index) const {
    if (index >= tensor_list_.size()) {
      MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
      return nullptr;
    }
    return &(tensor_list_[index]);
  }
  InferTensorBase *add() {
    tensor_list_.push_back(InferTensor());
    return &(tensor_list_.back());
  }
  void clear() { tensor_list_.clear(); }
  std::vector<InferTensor> &tensor_list_;
};

class VectorInferTensorWrapRequest : public RequestBase {
 public:
  explicit VectorInferTensorWrapRequest(const std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
X
xuyongfei 已提交
200
  ~VectorInferTensorWrapRequest() = default;
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215

  size_t size() const { return tensor_list_.size(); }
  const InferTensorBase *operator[](size_t index) const {
    if (index >= tensor_list_.size()) {
      MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
      return nullptr;
    }
    return &(tensor_list_[index]);
  }
  const std::vector<InferTensor> &tensor_list_;
};

}  // namespace inference
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_INFER_TENSOR_H_