提交 fd799513 编写于 作者: H haozech

reconstrcut infershape

上级 07ae2599
...@@ -38,28 +38,123 @@ class TensorLite; ...@@ -38,28 +38,123 @@ class TensorLite;
using DDim = lite::DDimLite; using DDim = lite::DDimLite;
using Tensor = lite::TensorLite; using Tensor = lite::TensorLite;
template <typename ValueType, size_t initLength>
class SmallVector {
public:
SmallVector() {
// VLOG(3)<<"call constructor";
data_ = new ValueType[initLength]();
// data_ = static_cast<ValueType *>(malloc(DimLength *
// sizeof(ValueType)));
// data_.resize(DimLength);
// memset(data_, 0, DimLength * sizeof(ValueType));
size_ = 0U;
memory_size = initLength;
}
~SmallVector() {
// VLOG(3)<<"call deconstructor";
if (data_ != nullptr) {
delete[] data_;
// free(data_);
}
data_ = nullptr;
size_ = 0U;
memory_size = 0U;
}
size_t size() const {
// VLOG(3)<<"call size()";
return size_;
}
void resize(size_t new_size) {
// VLOG(3)<<"call resize()";
if (new_size > memory_size) {
if (data_ != nullptr) {
delete[] data_;
}
data_ = new ValueType[new_size]();
memory_size = new_size;
}
size_ = new_size;
}
ValueType *mutable_data() { return data_; }
const ValueType *data() const {
// VLOG(3)<<"call data()";
return data_;
}
ValueType operator[](int offset) const {
// VLOG(3)<<"call operator[]";
return data_[offset];
}
ValueType &operator[](int offset) {
// VLOG(3)<<"call &operator[]";
return data_[offset];
}
private:
// ValueType data_[DimLength];
// ValueType* data_{nullptr};
ValueType *data_{nullptr};
size_t size_{0U};
size_t memory_size{0U};
};
class DDimLite { class DDimLite {
public: public:
constexpr static size_t init_length = 4;
using value_type = int64_t; using value_type = int64_t;
using DDimVector = SmallVector<value_type, init_length>;
DDimLite() = default; DDimLite() = default;
DDimLite(const DDimLite &a) {
data_.resize(a.size());
if (a.size() > 0U) {
memcpy(
data_.mutable_data(), a.data().data(), a.size() * sizeof(value_type));
} // deep copy
}
explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); } explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); }
// DDimLite(std::initializer_list<value_type> init_list) : // DDimLite(std::initializer_list<value_type> init_list) :
// DDimLite(std::vector<value_type>(init_list)) {} // DDimLite(std::vector<value_type>(init_list)) {}
void ConstructFrom(const std::vector<value_type> &x) { data_ = x; } void ConstructFrom(const std::vector<value_type> &x) {
data_.resize(x.size());
if (x.size() > 0U) {
memcpy(data_.mutable_data(), x.data(), x.size() * sizeof(value_type));
// std::copy(x.data(), x.data() + x.size(), data_.mutable_data());
}
}
value_type operator[](int offset) const { return data_[offset]; } value_type operator[](int offset) const { return data_[offset]; }
value_type &operator[](int offset) { return data_[offset]; } value_type &operator[](int offset) { return data_[offset]; }
std::vector<int64_t> Vectorize() const { return data_; } std::vector<value_type> Vectorize() const {
std::vector<value_type> vec;
vec.resize(data_.size());
if (data_.size() > 0U) {
memcpy(vec.data(), data_.data(), data_.size() * sizeof(value_type));
// std::copy(data_.data(), data_.data() + data_.size(), vec.data());
}
return vec;
}
size_t size() const { return data_.size(); } size_t size() const { return data_.size(); }
bool empty() const { return data_.empty(); } void resize(size_t size) { data_.resize(size); }
bool empty() const { return data_.size() == 0U; }
value_type production() const; value_type production() const;
const std::vector<value_type> &data() const { return data_; } const std::vector<value_type> data() const {
std::vector<value_type> vec;
vec.resize(data_.size());
if (data_.size() > 0U) {
memcpy(vec.data(), data_.data(), data_.size() * sizeof(value_type));
// std::copy(data_.data(), data_.data() + data_.size(), vec.data());
}
return vec;
}
value_type count(int start, int end) const; value_type count(int start, int end) const;
DDimLite Slice(int start, int end) const; DDimLite Slice(int start, int end) const;
...@@ -76,6 +171,18 @@ class DDimLite { ...@@ -76,6 +171,18 @@ class DDimLite {
return os; return os;
} }
DDimLite &operator=(const DDimLite &a) {
this->data_.resize(a.size());
if (a.size() > 0U) {
// std::copy(a.data().data(), a.data().data() + a.data().size(),
// this->data_.mutable_data());
memcpy(this->data_.mutable_data(),
a.data().data(),
a.size() * sizeof(value_type));
}
return *this;
}
friend bool operator==(const DDimLite &a, const DDimLite &b) { friend bool operator==(const DDimLite &a, const DDimLite &b) {
if (a.size() != b.size()) return false; if (a.size() != b.size()) return false;
for (size_t i = 0; i < a.size(); i++) { for (size_t i = 0; i < a.size(); i++) {
...@@ -93,7 +200,7 @@ class DDimLite { ...@@ -93,7 +200,7 @@ class DDimLite {
} }
private: private:
std::vector<value_type> data_; DDimVector data_;
}; };
using LoD = std::vector<std::vector<uint64_t>>; using LoD = std::vector<std::vector<uint64_t>>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册