提交 216ebcb9 编写于 作者: H He Wei

Enable Lazy allocation for Tensor

1. Tensor data is lazy allocated if no data to be copied;
2. Add prefix to Tensor id to distinguish with other id.
上级 e40ba6f8
...@@ -35,7 +35,7 @@ using Bool = unsigned char; ...@@ -35,7 +35,7 @@ using Bool = unsigned char;
static std::string MakeId() { static std::string MakeId() {
// Use atomic to make id generator thread safe. // Use atomic to make id generator thread safe.
static std::atomic<uint64_t> last_id{1}; static std::atomic<uint64_t> last_id{1};
return std::to_string(last_id.fetch_add(1, std::memory_order_relaxed)); return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
} }
static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) { static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
...@@ -127,41 +127,47 @@ std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_l ...@@ -127,41 +127,47 @@ std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_l
template <typename T> template <typename T>
class TensorDataImpl : public TensorData { class TensorDataImpl : public TensorData {
public: public:
explicit TensorDataImpl(const std::vector<int> &shape) : shape_(shape), data_(SizeOf(shape)) {} explicit TensorDataImpl(const std::vector<int> &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len) TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len)
: shape_(shape), data_(CopyData<T>(shape, data, data_len)) {} : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type) TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type)
: shape_(shape), data_(CopyData<T>(shape, data, data_type)) {} : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}
template <typename InputIt> template <typename InputIt>
TensorDataImpl(const std::vector<int> &shape, InputIt first, InputIt last) : shape_(shape), data_(first, last) {} TensorDataImpl(const std::vector<int> &shape, InputIt first, InputIt last)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {}
template <typename Scalar> template <typename Scalar>
TensorDataImpl(const std::vector<int> &shape, Scalar scalar) : shape_(shape), data_({static_cast<T>(scalar)}) {} TensorDataImpl(const std::vector<int> &shape, Scalar scalar)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast<T>(scalar)}) {}
ssize_t size() const override { return data_.size(); } ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); } ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }
ssize_t nbytes() const override { return size() * itemsize(); } ssize_t nbytes() const override { return size() * itemsize(); }
ssize_t ndim() const override { return static_cast<ssize_t>(shape_.size()); } ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
void *data() override { void *data() override {
static std::vector<T> empty_data(1); static std::vector<T> empty_data(1);
if (data_.empty()) { if (data_size_ == 0) {
// Prevent null pointer for empty data. // Prevent null pointer for empty shape.
return empty_data.data(); return empty_data.data();
} }
if (data_.empty()) {
// Lazy allocation.
data_.resize(data_size_);
}
return data_.data(); return data_.data();
} }
bool equals(const TensorData &other) const override { bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other); auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
if (ptr) { if (ptr) {
return (ptr == this) || ((shape_ == ptr->shape_) && (data_ == ptr->data_)); return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && (data_ == ptr->data_));
} }
return false; return false;
} }
...@@ -177,7 +183,8 @@ class TensorDataImpl : public TensorData { ...@@ -177,7 +183,8 @@ class TensorDataImpl : public TensorData {
} }
private: private:
std::vector<int> shape_; size_t ndim_{0};
size_t data_size_{0};
std::vector<T> data_; std::vector<T> data_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册