未验证 提交 8e7906d0 编写于 作者: H huzhiqiang 提交者: GitHub

fix the issue that: loading model consumes too much time test=decelop (#2726)

* fix the issue that: loading model consumes too much time test=decelop
上级 4430df40
...@@ -128,19 +128,23 @@ using Float64Builder = PrimaryBuilder<double>; ...@@ -128,19 +128,23 @@ using Float64Builder = PrimaryBuilder<double>;
template <typename Primary> template <typename Primary>
class PrimaryListBuilder : public FieldBuilder { class PrimaryListBuilder : public FieldBuilder {
std::vector<Primary> data_; const Primary* data_{nullptr};
int size_{0};
public: public:
using value_type = Primary; using value_type = Primary;
explicit PrimaryListBuilder(BinaryTable* table) : FieldBuilder(table) {} explicit PrimaryListBuilder(BinaryTable* table) : FieldBuilder(table) {}
PrimaryListBuilder(BinaryTable* table, const std::vector<Primary>& val) PrimaryListBuilder(BinaryTable* table, const Primary* val, int size)
: FieldBuilder(table), data_(val) {} : FieldBuilder(table), data_(val), size_(size) {}
/// Set data. /// Set data.
void set(const std::vector<Primary>& x) { data_ = x; } void set(const Primary* x, int size) {
data_ = x;
size_ = size;
}
const std::vector<Primary>& data() const { return data_; } const Primary* data() const { return data_; }
/// Save information to the corresponding BinaryTable. /// Save information to the corresponding BinaryTable.
void Save() override; void Save() override;
...@@ -149,14 +153,12 @@ class PrimaryListBuilder : public FieldBuilder { ...@@ -149,14 +153,12 @@ class PrimaryListBuilder : public FieldBuilder {
void Load() override; void Load() override;
/// Number of elements. /// Number of elements.
size_t size() const { return data_.size(); } size_t size() const { return size_; }
Type type() const override { Type type() const override { return core::StdTypeToRepr<const Primary*>(); }
return core::StdTypeToRepr<std::vector<Primary>>();
}
/// clear builder /// clear builder
void Clear() { data_.clear(); } void Clear() { size_ = 0; }
~PrimaryListBuilder() = default; ~PrimaryListBuilder() = default;
}; };
...@@ -381,17 +383,14 @@ void PrimaryBuilder<Primary>::Load() { ...@@ -381,17 +383,14 @@ void PrimaryBuilder<Primary>::Load() {
template <typename Primary> template <typename Primary>
void PrimaryListBuilder<Primary>::Load() { void PrimaryListBuilder<Primary>::Load() {
CHECK(data_.empty()) << "Duplicate load"; CHECK(data_ == nullptr) << "Duplicate load";
// Load number of elements first. // Load number of elements first.
uint64_t num_elems{}; uint64_t num_elems{};
memcpy(&num_elems, table()->cursor(), sizeof(uint64_t)); memcpy(&num_elems, table()->cursor(), sizeof(uint64_t));
table()->Consume(sizeof(uint64_t)); table()->Consume(sizeof(uint64_t));
data_.resize(num_elems); set(reinterpret_cast<Primary*>(table()->cursor()), num_elems);
for (uint64_t i = 0; i < num_elems; i++) { table()->Consume(num_elems * sizeof(value_type));
memcpy(&data_[i], table()->cursor(), sizeof(value_type));
table()->Consume(sizeof(value_type));
}
} }
template <typename Primary> template <typename Primary>
...@@ -404,7 +403,7 @@ void PrimaryListBuilder<Primary>::Save() { ...@@ -404,7 +403,7 @@ void PrimaryListBuilder<Primary>::Save() {
table()->Require(num_elems * sizeof(value_type)); table()->Require(num_elems * sizeof(value_type));
memcpy(table()->cursor(), memcpy(table()->cursor(),
reinterpret_cast<byte_t*>(&data_[0]), reinterpret_cast<const byte_t*>(data_),
num_elems * sizeof(value_type)); num_elems * sizeof(value_type));
table()->Consume(num_elems * sizeof(value_type)); table()->Consume(num_elems * sizeof(value_type));
} }
......
...@@ -150,9 +150,9 @@ void ParamDesc::SetDim(const std::vector<int64_t>& dim) { ...@@ -150,9 +150,9 @@ void ParamDesc::SetDim(const std::vector<int64_t>& dim) {
<< "Data Type mismatch"; \ << "Data Type mismatch"; \
std::vector<T> res; \ std::vector<T> res; \
auto& data_builder = desc_->GetField<PrimaryListBuilder<char>>("data"); \ auto& data_builder = desc_->GetField<PrimaryListBuilder<char>>("data"); \
auto& data = data_builder.data(); \ auto data = data_builder.data(); \
size_t size = data.size() / sizeof(T); \ size_t size = data_builder.size() / sizeof(T); \
auto* data_ptr = reinterpret_cast<const T*>(&data[0]); \ auto* data_ptr = reinterpret_cast<const T*>(data); \
for (size_t i = 0; i < size; ++i) { \ for (size_t i = 0; i < size; ++i) { \
res.push_back(data_ptr[i]); \ res.push_back(data_ptr[i]); \
} \ } \
...@@ -178,8 +178,7 @@ GET_DATA_IMPL(double, FP64); ...@@ -178,8 +178,7 @@ GET_DATA_IMPL(double, FP64);
data_builder->Clear(); \ data_builder->Clear(); \
size_t size = size__ * sizeof(T); \ size_t size = size__ * sizeof(T); \
auto* data_ptr = reinterpret_cast<const char*>(data_ptr__); \ auto* data_ptr = reinterpret_cast<const char*>(data_ptr__); \
std::vector<char> data_vec(data_ptr, data_ptr + size); \ data_builder->set(data_ptr, size);
data_builder->set(data_vec);
#define SET_DATA_IMPL(T, type__) \ #define SET_DATA_IMPL(T, type__) \
template <> \ template <> \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册