BufferArg.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <glog/logging.h>

#include "TensorShape.h"
#include "TensorType.h"
#include "paddle/math/Matrix.h"

namespace paddle {

enum BufferType {
  TENSOR_UNKNOWN = 0,
  TENSOR_NORMAL = 1,
  TENSOR_SEQUENCE_ID = 2,
  TENSOR_SEQUENCE_DATA = 3,
  TENSOR_SPARSE = 4
};

class BufferArg;
class SequenceArg;
class SparseMatrixArg;

/**
 * \brief BufferArg is used as the argument type of Function.
 *
 * The arguments of a Paddle Function fall into four Buffer types:
 * 1. BufferArg for a dense Buffer of any dimension.
 * 2. SequenceIdArg for a Buffer of sequence start positions.
 * 3. SequenceArg for a Buffer of sequence data.
 * 4. SparseMatrixArg for a Buffer of a sparse matrix.
 *
 * Buffer shape
 * For most buffers, the first dimension `shape()[0]` represents
 * the size of the mini-batch.
 *
 * Buffer argType
 * A BufferArg used as a Function output carries an ArgType property.
 * The argType_ of the output BufferArg determines whether the result
 * of the Function calculation is assigned to the output Buffer or
 * added to it.
 */

// ArgType is only used by output BufferArg.
// For an input argument, argType_ is ignored.
// For an output argument, the argType_ of the BufferArg needs to be set.
enum ArgType {
  UNSPECIFIED = 0,
  ASSIGN_TO = 1,
  ADD_TO = 2,
};
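
// Usage sketch (illustrative only): how the argType_ of an output BufferArg
// controls whether a Function overwrites the output buffer or accumulates
// into it. `CpuMatrix`, `batchSize` and `dim` below are assumptions made for
// the example, not part of this header.
//
//   CpuMatrix out(batchSize, dim);    // pre-existing output storage
//   BufferArg outArg(out, ADD_TO);    // results are added to `out`
//   // outArg.setArgType(ASSIGN_TO);  // ...or assigned, overwriting `out`
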
class BufferArg {
public:
  void setArgType(ArgType argType) { argType_ = argType; }

  ArgType getArgType() const { return argType_; }

public:
  BufferArg(ValueType valueType,
            const TensorShape& shape,
            ArgType argType = UNSPECIFIED)
      : buf_(nullptr),
        valueType_(valueType),
        shape_(shape),
        argType_(argType) {}

  BufferArg(void* buf,
            ValueType valueType,
            const TensorShape& shape,
            ArgType argType = UNSPECIFIED)
      : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {}

  BufferArg(void* buf, ValueType valueType)
      : buf_(buf), valueType_(valueType) {}

  BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED)
      : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
        valueType_(DataType<real>::value),
        shape_(2),
        argType_(argType) {
    bufferType_ = TENSOR_NORMAL;
    shape_.setDim(0, matrix.getHeight());
    shape_.setDim(1, matrix.getWidth());
  }

  BufferArg(const Matrix& matrix,
            const TensorShape& shape,
            ArgType argType = UNSPECIFIED)
      : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(matrix.getData()))),
        valueType_(DataType<real>::value),
        shape_(shape),
        argType_(argType) {
    bufferType_ = TENSOR_NORMAL;
    CHECK_EQ(matrix.getElementCnt(), shape.getElements());
  }

  BufferArg(const Vector& vector, ArgType argType = UNSPECIFIED)
      : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
        valueType_(DataType<real>::value),
        shape_(1),
        argType_(argType) {
    bufferType_ = TENSOR_NORMAL;
    shape_.setDim(0, vector.getSize());
  }

  BufferArg(const IVector& vector, ArgType argType = UNSPECIFIED)
      : buf_(
            const_cast<void*>(reinterpret_cast<const void*>(vector.getData()))),
        valueType_(VALUE_TYPE_INT32),
        shape_(1),
        argType_(argType) {
    bufferType_ = TENSOR_NORMAL;
    shape_.setDim(0, vector.getSize());
  }

  template <DeviceType DType>
  typename Tensor<real, DType>::Matrix matrix() const {
    CHECK(buf_);
    CHECK(valueType_ == DataType<real>::value);
    // CHECK(deviceType_ == DType);
    CHECK_EQ((size_t)2, shape_.ndims());
    return typename Tensor<real, DType>::Matrix(
        reinterpret_cast<real*>(buf_), shape_[0], shape_[1]);
  }
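
  // Example (a minimal sketch; DEVICE_TYPE_CPU comes from TensorType.h and
  // CpuMatrix from paddle/math/Matrix.h):
  //
  //   CpuMatrix data(16, 32);
  //   BufferArg arg(data);                       // shape_ becomes {16, 32}
  //   auto mat = arg.matrix<DEVICE_TYPE_CPU>();  // CpuMatrix wrapping buf_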

  template <typename VType, DeviceType DType>
  typename Tensor<VType, DType>::Vector vector() const {
    CHECK(buf_);
    CHECK(valueType_ == DataType<VType>::value);
    // CHECK(deviceType_ == DType);
    CHECK_EQ((size_t)1, shape_.ndims());
    return typename Tensor<VType, DType>::Vector(
        shape_[0], reinterpret_cast<VType*>(buf_));
  }

  virtual ~BufferArg() {}

  template <typename T>
  T* data() const {
    return reinterpret_cast<T*>(buf_);
  }

  void* data() const { return buf_; }
  ValueType valueType() const { return valueType_; }
  BufferType bufferType() const { return bufferType_; }
  const TensorShape& shape() const { return shape_; }
  bool isSparseArg() const { return TENSOR_SPARSE == bufferType_; }
  bool isSequenceArg() const { return TENSOR_SEQUENCE_DATA == bufferType_; }
  virtual size_t numElements() const { return shape_.getElements(); }

  const SequenceArg& sequence() const;
  const SparseMatrixArg& sparse() const;

protected:
  void* buf_;
  ValueType valueType_;
  TensorShape shape_;
  BufferType bufferType_{TENSOR_UNKNOWN};
  ArgType argType_{UNSPECIFIED};
  // todo(tianbing), add deviceType_
  // leading dimensions. The size is dims_.size()
  // Dims lds_;
};

// sequence start positions in a mini-batch of sequences
// shape_.ndims() == 1
// valueType_ = int32
// start positions are strictly increasing: if a < b then buf_[a] < buf_[b]
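// e.g. a mini-batch holding 3 sequences of lengths 4, 2 and 3 is described by
// the start positions {0, 4, 6, 9}: shape_[0] == 4 and numSeqs() == 3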
class SequenceIdArg : public BufferArg {
public:
  SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
      : BufferArg(VALUE_TYPE_INT32, shape, argType) {
    bufferType_ = TENSOR_SEQUENCE_ID;
    CHECK_EQ(shape_.ndims(), (size_t)1);
    CHECK_GT(shape_[0], 1);
    numSeqs_ = shape_[0] - 1;
  }

  SequenceIdArg(void* buf,
                const TensorShape& shape,
                ArgType argType = UNSPECIFIED)
      : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) {
    bufferType_ = TENSOR_SEQUENCE_ID;
    CHECK_EQ(shape_.ndims(), (size_t)1);
    numSeqs_ = shape_[0] - 1;
  }

  SequenceIdArg(const IVector& vector) : BufferArg(vector) {
    bufferType_ = TENSOR_SEQUENCE_ID;
    numSeqs_ = shape_[0] - 1;
  }

  ~SequenceIdArg() {}

  size_t numSeqs() const { return numSeqs_; }

private:
  size_t numSeqs_;
};

// sequence data
// In a mini-batch calculation,
// one batch can contain more than one sequence of data.
// SequenceArg can be used to represent a batch of
// sequences of unequal lengths.
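//
// Example (illustrative; `dim` and the CpuIVector contents are assumptions):
// a batch holding two sequences of lengths 4 and 2 is a 6 x dim data matrix
// plus the start positions {0, 4, 6}:
//
//   CpuMatrix data(6, dim);
//   CpuIVector starts(3);           // filled with {0, 4, 6}
//   SequenceArg seq(data, starts);  // bufferType_ == TENSOR_SEQUENCE_DATA
//   CHECK_EQ(seq.numSeqs(), (size_t)2);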
class SequenceArg : public BufferArg {
public:
  SequenceArg(ValueType valueType,
              const TensorShape& shape,
              ArgType argType = UNSPECIFIED)
      : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {
    bufferType_ = TENSOR_SEQUENCE_DATA;
  }

  SequenceArg(void* buf,
              ValueType valueType,
              const TensorShape& shape,
              const SequenceIdArg& startPositions,
              ArgType argType = UNSPECIFIED)
      : BufferArg(buf, valueType, shape, argType),
        startPositions_(startPositions) {
    bufferType_ = TENSOR_SEQUENCE_DATA;
  }

  SequenceArg(const Matrix& matrix,
              const IVector& vector,
              ArgType argType = UNSPECIFIED)
      : BufferArg(matrix, argType), startPositions_(vector) {
    bufferType_ = TENSOR_SEQUENCE_DATA;
  }

  ~SequenceArg() {}

  void* getIdBuf() const { return startPositions_.data(); }
  size_t numSeqs() const { return startPositions_.numSeqs(); }
  SequenceIdArg& getSequenceId() { return startPositions_; }
  const SequenceIdArg& getSequenceId() const { return startPositions_; }

private:
  SequenceIdArg startPositions_;
};

// sparse matrix
// valueType_ == float or double
// shape_.ndims() == 2
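// e.g. in CSR format, a (height x width) matrix with nnz non-zero values is
// described by nnz values in buf_, height + 1 row offsets in row_ and nnz
// column indices in col_; in CSC format the roles are swapped: nnz row
// indices in row_ and width + 1 column offsets in col_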
class SparseMatrixArg : public BufferArg {
public:
  SparseMatrixArg(void* buf,
                  ValueType valueType,
                  const TensorShape& shape,
                  const BufferArg& row,
                  const BufferArg& col,
                  size_t nnz,
                  SparseFormat format,
                  SparseValueType type,
                  ArgType argType = UNSPECIFIED)
      : BufferArg(buf, valueType, shape, argType),
        row_(row),
        col_(col),
        nnz_(nnz),
        format_(format),
        type_(type) {
    bufferType_ = TENSOR_SPARSE;
    CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
    CHECK_EQ(shape_.ndims(), (size_t)2);
    CHECK_EQ(row_.shape().ndims(), (size_t)1);
    CHECK_EQ(col_.shape().ndims(), (size_t)1);
    if (format == SPARSE_CSR) {
      CHECK_EQ(nnz, col.shape()[0]);
    } else if (format == SPARSE_CSC) {
      CHECK_EQ(nnz, row.shape()[0]);
    }
  }

  SparseMatrixArg(ValueType valueType,
                  const TensorShape& shape,
                  size_t nnz,
                  SparseFormat format,
                  SparseValueType type,
                  ArgType argType = UNSPECIFIED)
      : BufferArg(valueType, shape, argType),
        /// len of row_ : height + 1 (CSR), buf_ == nullptr
        row_(format == SPARSE_CSR
                 ? BufferArg(VALUE_TYPE_INT32, TensorShape{shape[0] + 1})
                 : BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})),
        /// len of col_ :  width + 1 (CSC), buf_ == nullptr
        col_(format == SPARSE_CSR
                 ? BufferArg(VALUE_TYPE_INT32, TensorShape{nnz})
                 : BufferArg(VALUE_TYPE_INT32, TensorShape{shape[1] + 1})),
        nnz_(nnz),
        format_(format),
        type_(type) {
    bufferType_ = TENSOR_SPARSE;
    /// todo(tianbing)
    /// valueType and shape_.ndims() == 2 need to be checked before
    /// calling this constructor to make sure row_ and col_ are right
    CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE));
    CHECK_EQ(shape_.ndims(), (size_t)2);
  }

  SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);

  SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED);

  template <DeviceType DType>
  typename Tensor<real, DType>::SparseMatrix SparseMatrix() const {
    CHECK(buf_);
    CHECK(valueType_ == DataType<real>::value);
    // CHECK(deviceType_ == DType);
    CHECK_EQ(shape_.ndims(), (size_t)2);
    return typename Tensor<real, DType>::SparseMatrix(
        reinterpret_cast<real*>(buf_),
        reinterpret_cast<int*>(row_.data()),
        reinterpret_cast<int*>(col_.data()),
        shape_[0],
        shape_[1],
        nnz_,
        type_,
        format_,
        false);
  }
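  // e.g. arg.SparseMatrix<DEVICE_TYPE_CPU>() reinterprets buf_, row_ and col_
  // as a CpuSparseMatrix of shape shape_[0] x shape_[1] with nnz_ non-zeros.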

  ~SparseMatrixArg() {}

  void* getRowBuf() const { return row_.data(); }

  void* getColBuf() const { return col_.data(); }

  size_t nnz() const { return nnz_; }

  size_t numElements() const override { return nnz_; }

  SparseFormat dataFormat() const { return format_; }

  SparseValueType dataType() const { return type_; }

private:
  BufferArg row_;
  BufferArg col_;
  size_t nnz_;
  SparseFormat format_;
  SparseValueType type_;
};

}  // namespace paddle