// Tensor.h
#ifndef PADDLE_OPTIMIZER_TENSOR_H_
#define PADDLE_OPTIMIZER_TENSOR_H_
/**
 * @brief tensor used by optimizer
 */

#include <string.h>
8
#include "optimizer.h"
9 10 11 12 13 14 15 16 17 18 19
#include "paddle/math/BaseMatrix.h"

namespace paddle {
namespace optimizer {

/**
 * @brief Base type that Tensor derives from.
 *
 * Plain alias for BaseMatrixT<Elem>; it adds no behavior of its own.
 */
template <typename Elem>
using TensorBase = BaseMatrixT<Elem>;

/**
 * @brief Flat, index-checked view over a buffer handed in by the caller.
 *
 * Derives from TensorBase<T> and relies on the base-class members `data_`
 * and `width_` for storage and length.
 */
template <class T>
class Tensor : public TensorBase<T> {
public:
  /**
   * @brief Wrap `data` as a tensor of `size` elements.
   *
   * The base class is constructed with (1, size, 0, data, false, false);
   * presumably height = 1, width = size, stride = 0, not transposed and not
   * on GPU -- TODO confirm against BaseMatrixT's constructor.
   *
   * @param data pointer forwarded unchanged to the base class.
   * @param size number of elements; must be non-negative.
   */
  Tensor(T* data, int size) : TensorBase<T>(1, size, 0, data, false, false) {
    // A negative size would be converted implicitly when forwarded to the
    // base class and make the bounds check in operator[] meaningless.
    CHECK(size >= 0) << " negative tensor size";
  }

  /** @brief Pointer to the underlying element storage. */
  T* get_buffer() { return this->data_; }

  /** @brief Read-only pointer to the underlying element storage. */
  const T* get_buffer() const { return this->data_; }

  /**
   * @brief Checked element access.
   * @param idx element position; must satisfy 0 <= idx < size().
   * @return reference to the element at `idx`.
   */
  T& operator[](const int idx) {
    check_index(idx);
    return this->data_[idx];
  }

  /**
   * @brief Checked read-only element access for const tensors.
   * @param idx element position; must satisfy 0 <= idx < size().
   * @return const reference to the element at `idx`.
   */
  const T& operator[](const int idx) const {
    check_index(idx);
    return this->data_[idx];
  }

  // TODO: replace with tensorshape
  /** @brief Number of elements, taken from the base-class `width_`. */
  size_t size() const { return this->width_; }

private:
  /**
   * @brief CHECK that `idx` lies inside [0, size()).
   *
   * The sign is tested first so the cast to size_t cannot wrap, and the
   * comparison against size() is done without mixing signed and unsigned.
   */
  void check_index(const int idx) const {
    CHECK(idx >= 0 && static_cast<size_t>(idx) < size())
        << " out of index range";
  }
};

}  // namespace optimizer
}  // namespace paddle

#endif