From fd79951348d6b51ce172799bb158da4193301a94 Mon Sep 17 00:00:00 2001 From: haozech Date: Wed, 17 Jun 2020 11:14:15 +0000 Subject: [PATCH] reconstrcut infershape --- lite/core/tensor.h | 119 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 6 deletions(-) diff --git a/lite/core/tensor.h b/lite/core/tensor.h index 2209e524f4..0fdd72458d 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -38,28 +38,123 @@ class TensorLite; using DDim = lite::DDimLite; using Tensor = lite::TensorLite; +template +class SmallVector { + public: + SmallVector() { + // VLOG(3)<<"call constructor"; + data_ = new ValueType[initLength](); + // data_ = static_cast(malloc(DimLength * + // sizeof(ValueType))); + // data_.resize(DimLength); + // memset(data_, 0, DimLength * sizeof(ValueType)); + size_ = 0U; + memory_size = initLength; + } + + ~SmallVector() { + // VLOG(3)<<"call deconstructor"; + if (data_ != nullptr) { + delete[] data_; + // free(data_); + } + data_ = nullptr; + size_ = 0U; + memory_size = 0U; + } + + size_t size() const { + // VLOG(3)<<"call size()"; + return size_; + } + void resize(size_t new_size) { + // VLOG(3)<<"call resize()"; + if (new_size > memory_size) { + if (data_ != nullptr) { + delete[] data_; + } + data_ = new ValueType[new_size](); + memory_size = new_size; + } + size_ = new_size; + } + + ValueType *mutable_data() { return data_; } + const ValueType *data() const { + // VLOG(3)<<"call data()"; + return data_; + } + + ValueType operator[](int offset) const { + // VLOG(3)<<"call operator[]"; + return data_[offset]; + } + ValueType &operator[](int offset) { + // VLOG(3)<<"call &operator[]"; + return data_[offset]; + } + + private: + // ValueType data_[DimLength]; + // ValueType* data_{nullptr}; + ValueType *data_{nullptr}; + size_t size_{0U}; + size_t memory_size{0U}; +}; + class DDimLite { public: + constexpr static size_t init_length = 4; using value_type = int64_t; + using DDimVector = SmallVector; DDimLite() = default; + DDimLite(const DDimLite &a) { + data_.resize(a.size()); + if (a.size() > 0U) { + memcpy( + data_.mutable_data(), a.data().data(), a.size() * sizeof(value_type)); + } // deep copy + } explicit DDimLite(const std::vector &x) { ConstructFrom(x); } // DDimLite(std::initializer_list init_list) : // DDimLite(std::vector(init_list)) {} - void ConstructFrom(const std::vector &x) { data_ = x; } - + void ConstructFrom(const std::vector &x) { + data_.resize(x.size()); + if (x.size() > 0U) { + memcpy(data_.mutable_data(), x.data(), x.size() * sizeof(value_type)); + // std::copy(x.data(), x.data() + x.size(), data_.mutable_data()); + } + } value_type operator[](int offset) const { return data_[offset]; } value_type &operator[](int offset) { return data_[offset]; } - std::vector Vectorize() const { return data_; } + std::vector Vectorize() const { + std::vector vec; + vec.resize(data_.size()); + if (data_.size() > 0U) { + memcpy(vec.data(), data_.data(), data_.size() * sizeof(value_type)); + // std::copy(data_.data(), data_.data() + data_.size(), vec.data()); + } + return vec; + } size_t size() const { return data_.size(); } - bool empty() const { return data_.empty(); } + void resize(size_t size) { data_.resize(size); } + bool empty() const { return data_.size() == 0U; } value_type production() const; - const std::vector &data() const { return data_; } + const std::vector data() const { + std::vector vec; + vec.resize(data_.size()); + if (data_.size() > 0U) { + memcpy(vec.data(), data_.data(), data_.size() * sizeof(value_type)); + // std::copy(data_.data(), data_.data() + data_.size(), vec.data()); + } + return vec; + } value_type count(int start, int end) const; DDimLite Slice(int start, int end) const; @@ -76,6 +171,18 @@ class DDimLite { return os; } + DDimLite &operator=(const DDimLite &a) { + this->data_.resize(a.size()); + if (a.size() > 0U) { + // std::copy(a.data().data(), a.data().data() + a.data().size(), + // this->data_.mutable_data()); + memcpy(this->data_.mutable_data(), + a.data().data(), + a.size() * sizeof(value_type)); + } + return *this; + } + friend bool operator==(const DDimLite &a, const DDimLite &b) { if (a.size() != b.size()) return false; for (size_t i = 0; i < a.size(); i++) { @@ -93,7 +200,7 @@ class DDimLite { } private: - std::vector data_; + DDimVector data_; }; using LoD = std::vector>; -- GitLab