cnnpredict_interface.h 2.8 KB
Newer Older
K
Kaibing Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

/// Element-type tag carried by DataBuf (and therefore by Tensor payloads).
///
/// The numeric values are spelled out explicitly and value 1 is skipped.
/// NOTE(review): presumably they mirror an external dtype code table —
/// confirm before renumbering or assigning 1 to a new type.
enum class DataType : int { INT8 = 0, INT32 = 2, INT64 = 3, FLOAT32 = 4 };

/// Returns the size in bytes of a single element of `type`.
///
/// A value that is not one of the named enumerators (e.g. an int cast to
/// DataType) yields 0, which callers multiply into a zero-byte buffer.
inline size_t get_type_size(DataType type) {
  size_t element_bytes = 0;
  if (type == DataType::INT8) {
    element_bytes = sizeof(int8_t);
  } else if (type == DataType::INT32) {
    element_bytes = sizeof(int32_t);
  } else if (type == DataType::INT64) {
    element_bytes = sizeof(int64_t);
  } else if (type == DataType::FLOAT32) {
    element_bytes = sizeof(float);
  }
  return element_bytes;
}

struct DataBuf {
  std::size_t size;
  DataType type;
  std::shared_ptr<char> data;

  DataBuf() = default;

  DataBuf(DataType dtype, size_t dsize) { alloc(dtype, dsize); }

  DataBuf(const void *ddata, DataType dtype, size_t dsize) {
    alloc(dtype, dsize);
    copy(ddata, dsize);
  }

  DataBuf(const DataBuf &dbuf)
      : size(dbuf.size), type(dbuf.type), data(dbuf.data) {}

  DataBuf &operator=(const DataBuf &dbuf) {
    size = dbuf.size;
    type = dbuf.type;
    data = dbuf.data;
    return *this;
  }

  void reset(const void *ddata, size_t dsize) {
    clear();
    alloc(type, dsize);
    copy(ddata, dsize);
  }

  void clear() {
    size = 0;
    data.reset();
  }

  ~DataBuf() { clear(); }

 private:
  void alloc(DataType dtype, size_t dsize) {
    type = dtype;
    size = dsize;
    data.reset(new char[dsize * get_type_size(dtype)],
               std::default_delete<char[]>());
  }

  void copy(const void *ddata, size_t dsize) {
    const char *temp = reinterpret_cast<const char *>(ddata);
    std::copy(temp, temp + dsize * get_type_size(type), data.get());
  }
};

/// A named tensor passed to and returned from ICNNPredict::predict.
///
/// Plain aggregate: member order below is also the aggregate-initialization
/// order, so do not reorder fields.
struct Tensor {
  // Tensor name. NOTE(review): presumably matches the model's input/output
  // names — confirm against the implementation.
  std::string name;
  // Dimension sizes.
  std::vector<int> shape;
  // NOTE(review): presumably PaddlePaddle-style LoD (level-of-detail)
  // sequence offsets — confirm; empty when unused is an assumption.
  std::vector<std::vector<size_t>> lod;
  // Element payload. Copying a Tensor copies this shallowly (the DataBuf
  // copy shares its storage).
  DataBuf data;
};

class ICNNPredict {
 public:
  ICNNPredict() {}
  virtual ~ICNNPredict() {}

  virtual ICNNPredict *clone() = 0;

  virtual bool predict(const std::vector<Tensor> &inputs,
                       const std::vector<std::string> &layers,
                       std::vector<Tensor> &outputs) = 0;

  virtual bool predict(const std::vector<std::vector<float>> &input_datas,
                       const std::vector<std::vector<int>> &input_shapes,
                       const std::vector<std::string> &layers,
                       std::vector<std::vector<float>> &output_datas,
                       std::vector<std::vector<int>> &output_shapes) = 0;

  virtual void destroy(std::vector<Tensor> &tensors) {
    std::vector<Tensor>().swap(tensors);
  }

  virtual void destroy(std::vector<std::vector<float>> &datas) {
    std::vector<std::vector<float>>().swap(datas);
  }

  virtual void destroy(std::vector<std::vector<int>> &shapes) {
    std::vector<std::vector<int>>().swap(shapes);
  }
};

/// Factory returning an ICNNPredict implementation.
///
/// @param conf_file configuration file path. NOTE(review): presumably the
///                  model/predictor config — confirm.
/// @param prefix    NOTE(review): meaning is not visible in this header
///                  (key prefix? model directory?) — document.
/// @return pointer to a predictor. NOTE(review): ownership and whether null
///         is returned on failure are not visible here — presumably the
///         caller deletes it; confirm against the implementation.
ICNNPredict *create_cnnpredict(const std::string &conf_file,
                               const std::string &prefix);