// Tensor.h
#ifndef PADDLE_OPTIMIZER_TENSOR_H_
#define PADDLE_OPTIMIZER_TENSOR_H_
/**
 * @brief tensor used by optimizer
 */

#include <string.h>
#include "paddle/math/BaseMatrix.h"

namespace paddle {
namespace optimizer {

/// Base class of the optimizer tensor: an alias for BaseMatrixT<T>
/// (declared in paddle/math/BaseMatrix.h). TensorT derives from it.
template <typename T>
using TensorBase = BaseMatrixT<T>;

template <class T>
D
dzhwinter 已提交
17
class TensorT : public TensorBase<T> {
18
public:
D
dzhwinter 已提交
19 20 21 22 23 24 25
  TensorT(T* data, int size) : TensorBase<T>(1, size, 0, data, false, false) {}
  TensorT(const TensorT& t)
      : TensorBase<T>(1, t.size(), 0, t.get_buffer(), false, false) {}
  TensorT& operator=(const TensorT& t) {
    this->size_ = t.size();
    this->data_ = t.get_buffer();
  }
26
  T* get_buffer() { return this->data_; }
27
  T& operator[](const int idx) {
D
dzhwinter 已提交
28
    CHECK(idx >= 0 && idx < this->width_) << "out of index range";
29 30
    return this->data_[idx];
  }
31
  // TODO: replace with tensorshape
32
  size_t size() const { return this->width_; }
33 34
};

D
dzhwinter 已提交
35 36 37
// TODO(zhihong): design problem of dynamic datatype, need to fix
typedef TensorT<real> Tensor;

38 39 40 41
}  // namespace optimizer
}  // namespace paddle

#endif