// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "common/inner_common.h"

namespace baidu {
namespace paddle_serving {
namespace predictor {

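// Element data types a tensor can hold; ele_byte() below maps each to its
// width in bytes.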
enum DataType { FLOAT32, INT64 };

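// A raw byte buffer with optional ownership. When _owned is true, free() and
// the destructor release the memory; otherwise the object is only a view of
// memory managed elsewhere. Note that copies are compiler-generated and
// shallow (they also copy _owned), so avoid copying an owning DataBuf: two
// owners would delete the same buffer.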
class DataBuf {
 public:
  DataBuf() : _data(NULL), _size(0), _owned(true) {}

  explicit DataBuf(size_t size)
      : _data(new char[size]), _size(size), _owned(true) {}

  DataBuf(void* data, size_t size) : _data(data), _size(size), _owned(false) {}

  DataBuf(void* data, size_t size, bool owned)
      : _data(data), _size(size), _owned(owned) {}

  void* data() const { return _data; }

  size_t size() const { return _size; }

  void free() {
    _size = 0;
    if (_owned) {
      delete[](reinterpret_cast<char*>(_data));
      _data = NULL;  // guard against a double delete if free() runs again
    }
  }

  ~DataBuf() { free(); }

 private:
  void* _data;
  size_t _size;
  bool _owned;
};

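// A single named input/output tensor: element type, shape, raw bytes, and
// LoD (level-of-details) offsets, Paddle's encoding of variable-length
// sequences.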
struct Tensor {
  Tensor() {
    shape.clear();
    lod.clear();
  }

  Tensor(const Tensor& tensor) {
    name = tensor.name;
    data = tensor.data;
    type = tensor.type;
    shape.assign(tensor.shape.begin(), tensor.shape.end());
    for (size_t li = 0; li < tensor.lod.size(); ++li) {
      std::vector<size_t> l;
      l.assign(tensor.lod[li].begin(), tensor.lod[li].end());
      lod.push_back(l);
    }
  }

  ~Tensor() { shape.clear(); }

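  // Width in bytes of a single element of this tensor's type.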
  size_t ele_byte() const {
    if (type == INT64) {
      return sizeof(int64_t);
    } else {
      return sizeof(float);
    }
  }

  // Checks that shape, data pointer/size, and element width are consistent.
  bool valid() const {
    if (shape.empty()) {
      if (data.data() || data.size()) {
        LOG(ERROR) << "data should be empty";
        return false;
      }
      return true;
    }

    if (!data.data() || !data.size()) {
      LOG(ERROR) << "data cannot be empty";
      return false;
    }

    size_t ele_count = 1;
    for (size_t si = 0; si < shape.size(); ++si) {
      ele_count *= shape[si];
    }

    if (ele_count * ele_byte() != data.size()) {
      LOG(ERROR) << "wrong data size: " << ele_count * ele_byte() << " vs. "
                 << data.size();
      return false;
    }

    return true;
  }

  // First shape dimension, conventionally the batch size; 0 when empty.
  size_t shape0() const {
    if (shape.empty()) {
      return 0;
    }
    return shape[0];
  }

  std::string name;
  std::vector<int> shape;
  DataBuf data;
  DataType type;
  std::vector<std::vector<size_t>> lod;
};

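// The ordered set of Tensor features that make up one inference batch.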
class BatchTensor {
 public:
  BatchTensor() {}
  ~BatchTensor() { _features.clear(); }

  BatchTensor(const BatchTensor& tv) {
    _features.assign(tv.features().begin(), tv.features().end());
  }

  Tensor& operator[](int index) { return _features[index]; }

  const Tensor& operator[](int index) const { return _features[index]; }

  void push_back(const Tensor& tensor) { _features.push_back(tensor); }

  size_t count() const { return _features.size(); }

  size_t size() const {
    // shape[0] of the first feature indicates the batch size
    if (count() == 0 || _features[0].shape.empty()) {
      return 0;
    }
    return _features[0].shape[0];
  }

  const std::vector<Tensor>& features() const { return _features; }

  void clear() { _features.clear(); }

 private:
  std::vector<Tensor> _features;
};

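// A minimal usage sketch (illustrative, not part of the original header);
// the feature name and values below are hypothetical. The caller keeps
// ownership of the buffer because DataBuf copies are shallow.
//
//   float* buf = new float[6];  // 2 x 3 elements, filled by the caller
//   Tensor t;
//   t.name = "input";  // hypothetical feature name
//   t.type = FLOAT32;
//   t.shape.push_back(2);
//   t.shape.push_back(3);
//   t.data = DataBuf(buf, 6 * sizeof(float));  // non-owning view
//   // t.valid() == true: 6 elements * sizeof(float) == data.size()
//
//   BatchTensor batch;
//   batch.push_back(t);
//   // batch.count() == 1 feature, batch.size() == 2 (batch dimension)
//
//   delete[] buf;  // caller frees the memory once the batch is done
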
}  // namespace predictor
}  // namespace paddle_serving
}  // namespace baidu