Tensor.h 593 字节
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
#ifndef PADDLE_OPTIMIZER_TENSOR_H_
#define PADDLE_OPTIMIZER_TENSOR_H_
/**
 * @brief Tensor type used by the optimizer.
 */

#include <string.h>
#include "paddle/math/BaseMatrix.h"

namespace paddle {
namespace optimizer {

/// Storage base for optimizer tensors: an alias of Paddle's BaseMatrixT
/// instantiated for the tensor's element type.
template <typename Elem>
using TensorBase = BaseMatrixT<Elem>;

/**
 * @brief One-dimensional tensor over a caller-supplied buffer of T.
 *
 * The buffer is modelled as a single row (height 1, width `size`), so
 * width() reports the number of elements.
 *
 * NOTE(review): this class never allocates or frees `data`; presumably the
 * caller owns the buffer and must keep it alive for this object's lifetime —
 * confirm against BaseMatrixT's ownership semantics.
 */
template <class T>
class Tensor : public TensorBase<T> {
public:
  /**
   * @param data pointer to `size` contiguous elements (stored, not copied).
   * @param size number of elements; expected to be non-negative.
   */
  Tensor(T* data, int size)
      // BaseMatrixT's six-argument constructor takes
      // (height, width, stride, data, trans, useGpu). Shape the buffer as a
      // 1 x size row so that width_ == size. A (size, 1) column shape would
      // make width() return 1, so loops bounded by width() would only visit
      // the first element. The row stride equals the row length because the
      // buffer is contiguous.
      : TensorBase<T>(1, size, size, data, false, false) {}

  /// @return the raw element pointer passed to the constructor.
  T* get_buffer() const { return this->data_; }

  // TODO: replace with tensorshape
  /// @return the number of elements in the tensor.
  size_t width() const { return this->width_; }
};

}  // namespace optimizer
}  // namespace paddle

#endif