提交 b0f94214 编写于 作者: H huzhiqiang 提交者: GitHub

Fix ‘Large memory usage of Naive model loading’ (#2175)


Fix ‘Large memory usage of Naive model loading’  (#2175)
上级 2f57f5b4
......@@ -82,6 +82,10 @@ Type StdTypeToRepr<double>() {
return Type::_float64;
}
// A char vector maps to the dedicated contiguous char-list type tag.
template <>
Type StdTypeToRepr<std::vector<char>>() { return Type::_char_list; }
// std::string maps to the string type tag.
template <>
Type StdTypeToRepr<std::string>() { return Type::_string; }
......
......@@ -16,6 +16,7 @@
#include <stack>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/utils/all.h"
......@@ -36,7 +37,9 @@ enum class Type {
_float64,
_bool,
_string,
// primary list types
// primary list type
_char_list,
// list types
_list,
// enum type
_enum,
......@@ -89,6 +92,8 @@ Type StdTypeToRepr<float>();
template <>
Type StdTypeToRepr<bool>();
template <>
Type StdTypeToRepr<std::vector<char>>();
template <>
Type StdTypeToRepr<std::string>();
// Factors that impact the kernel picking strategy. Multiple factors can be
......
......@@ -727,10 +727,8 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
// Load model
std::string prog_path = model_buffer;
naive_buffer::BinaryTable table;
table.LoadFromMemory(prog_path.c_str(), prog_path.length());
table.LoadFromMemory(model_buffer.c_str(), model_buffer.length());
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
nb_proto_prog.Load();
......@@ -742,8 +740,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
// Load Params
// NOTE: Only main block be used now.
// only combined Params are supported in Loading Model from memory
std::string combined_params_path = param_buffer;
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true);
LoadCombinedParamsNaive(param_buffer, scope, *cpp_prog, true);
VLOG(4) << "Load model from naive buffer memory successfully";
}
......
......@@ -126,6 +126,41 @@ using UInt64Builder = PrimaryBuilder<uint64_t>;
using Float32Builder = PrimaryBuilder<float>;
using Float64Builder = PrimaryBuilder<double>;
// Builder for a flat list of primary-typed elements (used here with char).
// The elements live in one contiguous std::vector so serialization and
// deserialization can be done in bulk, rather than element-by-element via
// a ListBuilder of per-element builders — this is the mechanism behind the
// "large memory usage of naive model loading" fix in this commit.
// NOTE(review): Save/Load (defined below) memcpy the raw bytes, so Primary
// is assumed trivially copyable — confirm if instantiated with new types.
template <typename Primary>
class PrimaryListBuilder : public FieldBuilder {
  // Contiguous storage for all list elements.
  std::vector<Primary> data_;
public:
  using value_type = Primary;
  explicit PrimaryListBuilder(BinaryTable* table) : FieldBuilder(table) {}
  PrimaryListBuilder(BinaryTable* table, const std::vector<Primary>& val)
  : FieldBuilder(table), data_(val) {}
  /// Set data.
  void set(const std::vector<Primary>& x) { data_ = x; }
  // Read-only view of the stored elements.
  const std::vector<Primary>& data() const { return data_; }
  /// Save information to the corresponding BinaryTable.
  void Save() override;
  /// Load information from the corresponding BinaryTable.
  void Load() override;
  /// Number of elements.
  size_t size() const { return data_.size(); }
  // Type tag, e.g. Type::_char_list for PrimaryListBuilder<char>.
  Type type() const override {
  return core::StdTypeToRepr<std::vector<Primary>>();
  }
  /// clear builder
  void Clear() { data_.clear(); }
  ~PrimaryListBuilder() = default;
};
/*
* Builder for all the primary types. int32, float, bool and so on.
*/
......@@ -344,6 +379,36 @@ void PrimaryBuilder<Primary>::Load() {
table()->Consume(sizeof(value_type));
}
// Deserializes the list from the table cursor.
// Wire layout: [uint64_t element count][count * sizeof(Primary) raw bytes].
template <typename Primary>
void PrimaryListBuilder<Primary>::Load() {
  CHECK(data_.empty()) << "Duplicate load";
  // Load number of elements first.
  uint64_t num_elems{};
  memcpy(&num_elems, table()->cursor(), sizeof(uint64_t));
  table()->Consume(sizeof(uint64_t));
  data_.resize(num_elems);
  if (num_elems > 0) {
    // Primary is trivially copyable (used with char here), so copy the whole
    // payload with one bulk memcpy instead of one memcpy+Consume per element.
    const size_t payload_bytes = num_elems * sizeof(value_type);
    memcpy(data_.data(), table()->cursor(), payload_bytes);
    table()->Consume(payload_bytes);
  }
}
// Serializes the list to the table cursor.
// Wire layout: [uint64_t element count][count * sizeof(Primary) raw bytes].
template <typename Primary>
void PrimaryListBuilder<Primary>::Save() {
  // store number of elements in the head.
  const uint64_t num_elems = size();
  table()->Require(sizeof(uint64_t));
  memcpy(table()->cursor(), &num_elems, sizeof(uint64_t));
  table()->Consume(sizeof(uint64_t));
  if (num_elems == 0) {
    // Nothing to write; also avoids evaluating &data_[0] on an empty
    // vector, which is undefined behavior.
    return;
  }
  const size_t payload_bytes = num_elems * sizeof(value_type);
  table()->Require(payload_bytes);
  memcpy(table()->cursor(),
         reinterpret_cast<const byte_t*>(data_.data()),
         payload_bytes);
  table()->Consume(payload_bytes);
}
template <typename EnumType>
void EnumBuilder<EnumType>::Save() {
value_type holder = static_cast<value_type>(data_);
......
......@@ -149,15 +149,16 @@ void ParamDesc::SetDim(const std::vector<int64_t>& dim) {
CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \
<< "Data Type mismatch"; \
std::vector<T> res; \
auto& data_builder = desc_->GetField<ListBuilder<CharBuilder>>("data"); \
auto data = RepeatedToVector<char, CharBuilder>(data_builder); \
auto& data_builder = desc_->GetField<PrimaryListBuilder<char>>("data"); \
auto& data = data_builder.data(); \
size_t size = data.size() / sizeof(T); \
auto* data_ptr = reinterpret_cast<T*>(&data[0]); \
auto* data_ptr = reinterpret_cast<const T*>(&data[0]); \
for (size_t i = 0; i < size; ++i) { \
res.push_back(data_ptr[i]); \
} \
return res; \
}
GET_DATA_IMPL(uint8_t, UINT8);
GET_DATA_IMPL(int8_t, INT8);
GET_DATA_IMPL(int16_t, INT16);
......@@ -172,14 +173,13 @@ GET_DATA_IMPL(double, FP64);
CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \
<< "Data Type mismatch, call SetDataType first."; \
auto* data_builder = \
desc_->GetMutableField<ListBuilder<CharBuilder>>("data"); \
desc_->GetMutableField<PrimaryListBuilder<char>>("data"); \
CHECK(data_builder); \
data_builder->Clear(); \
size_t size = size__ * sizeof(T); \
auto* data_ptr = reinterpret_cast<const char*>(data_ptr__); \
for (size_t i = 0; i < size; ++i) { \
data_builder->New()->set(data_ptr[i]); \
}
std::vector<char> data_vec(data_ptr, data_ptr + size); \
data_builder->set(data_vec);
#define SET_DATA_IMPL(T, type__) \
template <> \
......
......@@ -191,7 +191,7 @@ class ParamDesc : public StructBuilder {
New<lod_type>("lod");
NewUInt32("tensor_version");
New<TensorDesc>("tensor_desc");
New<ListBuilder<CharBuilder>>("data");
New<PrimaryListBuilder<char>>("data");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册