提交 b0f94214 编写于 作者: H huzhiqiang 提交者: GitHub

Fix ‘Large memory usage of Naive model loading’ (#2175)


Fix ‘Large memory usage of Naive model loading’  (#2175)
上级 2f57f5b4
...@@ -82,6 +82,10 @@ Type StdTypeToRepr<double>() { ...@@ -82,6 +82,10 @@ Type StdTypeToRepr<double>() {
return Type::_float64; return Type::_float64;
} }
template <> template <>
Type StdTypeToRepr<std::vector<char>>() {
return Type::_char_list;
}
template <>
Type StdTypeToRepr<std::string>() { Type StdTypeToRepr<std::string>() {
return Type::_string; return Type::_string;
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <stack> #include <stack>
#include <string> #include <string>
#include <vector>
#include "lite/api/paddle_place.h" #include "lite/api/paddle_place.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
...@@ -36,7 +37,9 @@ enum class Type { ...@@ -36,7 +37,9 @@ enum class Type {
_float64, _float64,
_bool, _bool,
_string, _string,
// primary list types // primary list type
_char_list,
// list types
_list, _list,
// enum type // enum type
_enum, _enum,
...@@ -89,6 +92,8 @@ Type StdTypeToRepr<float>(); ...@@ -89,6 +92,8 @@ Type StdTypeToRepr<float>();
template <> template <>
Type StdTypeToRepr<bool>(); Type StdTypeToRepr<bool>();
template <> template <>
Type StdTypeToRepr<std::vector<char>>();
template <>
Type StdTypeToRepr<std::string>(); Type StdTypeToRepr<std::string>();
// Factors that impact the kernel picking strategy. Multiple factors can be // Factors that impact the kernel picking strategy. Multiple factors can be
......
...@@ -727,10 +727,8 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer, ...@@ -727,10 +727,8 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
// Load model // Load model
std::string prog_path = model_buffer;
naive_buffer::BinaryTable table; naive_buffer::BinaryTable table;
table.LoadFromMemory(prog_path.c_str(), prog_path.length()); table.LoadFromMemory(model_buffer.c_str(), model_buffer.length());
naive_buffer::proto::ProgramDesc nb_proto_prog(&table); naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
nb_proto_prog.Load(); nb_proto_prog.Load();
...@@ -742,8 +740,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer, ...@@ -742,8 +740,7 @@ void LoadModelNaiveFromMemory(const std::string &model_buffer,
// Load Params // Load Params
// NOTE: Only main block be used now. // NOTE: Only main block be used now.
// only combined Params are supported in Loading Model from memory // only combined Params are supported in Loading Model from memory
std::string combined_params_path = param_buffer; LoadCombinedParamsNaive(param_buffer, scope, *cpp_prog, true);
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true);
VLOG(4) << "Load model from naive buffer memory successfully"; VLOG(4) << "Load model from naive buffer memory successfully";
} }
......
...@@ -126,6 +126,41 @@ using UInt64Builder = PrimaryBuilder<uint64_t>; ...@@ -126,6 +126,41 @@ using UInt64Builder = PrimaryBuilder<uint64_t>;
using Float32Builder = PrimaryBuilder<float>; using Float32Builder = PrimaryBuilder<float>;
using Float64Builder = PrimaryBuilder<double>; using Float64Builder = PrimaryBuilder<double>;
// Field builder holding a contiguous list of primary-typed elements in a
// single std::vector<Primary>, so the payload can be (de)serialized with one
// bulk copy (see Save()/Load() below) rather than one table field per
// element. NOTE(review): presumably introduced to cut the memory overhead of
// naive-buffer model loading (commit title) — confirm against ListBuilder.
// Save()/Load() memcpy raw element bytes, so Primary is assumed to be a
// trivially copyable type (e.g. char) — TODO confirm with callers.
template <typename Primary>
class PrimaryListBuilder : public FieldBuilder {
  std::vector<Primary> data_;  // owned element storage

 public:
  using value_type = Primary;

  explicit PrimaryListBuilder(BinaryTable* table) : FieldBuilder(table) {}
  PrimaryListBuilder(BinaryTable* table, const std::vector<Primary>& val)
      : FieldBuilder(table), data_(val) {}

  /// Set data (copies x into the internal buffer).
  void set(const std::vector<Primary>& x) { data_ = x; }

  /// Read-only access to the stored elements.
  const std::vector<Primary>& data() const { return data_; }

  /// Save information to the corresponding BinaryTable.
  void Save() override;

  /// Load information from the corresponding BinaryTable.
  void Load() override;

  /// Number of elements.
  size_t size() const { return data_.size(); }

  Type type() const override {
    return core::StdTypeToRepr<std::vector<Primary>>();
  }

  /// clear builder
  void Clear() { data_.clear(); }

  ~PrimaryListBuilder() = default;
};
/* /*
* Builder for all the primary types. int32, float, bool and so on. * Builder for all the primary types. int32, float, bool and so on.
*/ */
...@@ -344,6 +379,36 @@ void PrimaryBuilder<Primary>::Load() { ...@@ -344,6 +379,36 @@ void PrimaryBuilder<Primary>::Load() {
table()->Consume(sizeof(value_type)); table()->Consume(sizeof(value_type));
} }
// Deserialize the element list: a uint64 element count followed by the raw
// element bytes, in the exact layout produced by Save().
template <typename Primary>
void PrimaryListBuilder<Primary>::Load() {
  CHECK(data_.empty()) << "Duplicate load";
  // Load number of elements first.
  uint64_t num_elems{};
  memcpy(&num_elems, table()->cursor(), sizeof(uint64_t));
  table()->Consume(sizeof(uint64_t));

  data_.resize(num_elems);
  if (num_elems > 0) {
    // Elements are stored contiguously (Save() writes them with a single
    // memcpy), so read them back in one bulk copy instead of a per-element
    // memcpy/Consume loop; the total bytes consumed are identical.
    memcpy(data_.data(), table()->cursor(), num_elems * sizeof(value_type));
    table()->Consume(num_elems * sizeof(value_type));
  }
}
// Serialize the element list: a uint64 element count followed by the raw
// element bytes in one bulk copy.
template <typename Primary>
void PrimaryListBuilder<Primary>::Save() {
  // store number of elements in the head.
  uint64_t num_elems = size();
  table()->Require(sizeof(uint64_t));
  memcpy(table()->cursor(), &num_elems, sizeof(uint64_t));
  table()->Consume(sizeof(uint64_t));

  if (num_elems > 0) {
    // Guard the empty case: the original `&data_[0]` indexes an empty
    // vector, which is undefined behavior; data_.data() plus the guard is
    // well-defined and writes the same bytes when non-empty.
    table()->Require(num_elems * sizeof(value_type));
    memcpy(table()->cursor(),
           reinterpret_cast<const byte_t*>(data_.data()),
           num_elems * sizeof(value_type));
    table()->Consume(num_elems * sizeof(value_type));
  }
}
template <typename EnumType> template <typename EnumType>
void EnumBuilder<EnumType>::Save() { void EnumBuilder<EnumType>::Save() {
value_type holder = static_cast<value_type>(data_); value_type holder = static_cast<value_type>(data_);
......
...@@ -149,15 +149,16 @@ void ParamDesc::SetDim(const std::vector<int64_t>& dim) { ...@@ -149,15 +149,16 @@ void ParamDesc::SetDim(const std::vector<int64_t>& dim) {
CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \ CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \
<< "Data Type mismatch"; \ << "Data Type mismatch"; \
std::vector<T> res; \ std::vector<T> res; \
auto& data_builder = desc_->GetField<ListBuilder<CharBuilder>>("data"); \ auto& data_builder = desc_->GetField<PrimaryListBuilder<char>>("data"); \
auto data = RepeatedToVector<char, CharBuilder>(data_builder); \ auto& data = data_builder.data(); \
size_t size = data.size() / sizeof(T); \ size_t size = data.size() / sizeof(T); \
auto* data_ptr = reinterpret_cast<T*>(&data[0]); \ auto* data_ptr = reinterpret_cast<const T*>(&data[0]); \
for (size_t i = 0; i < size; ++i) { \ for (size_t i = 0; i < size; ++i) { \
res.push_back(data_ptr[i]); \ res.push_back(data_ptr[i]); \
} \ } \
return res; \ return res; \
} }
GET_DATA_IMPL(uint8_t, UINT8); GET_DATA_IMPL(uint8_t, UINT8);
GET_DATA_IMPL(int8_t, INT8); GET_DATA_IMPL(int8_t, INT8);
GET_DATA_IMPL(int16_t, INT16); GET_DATA_IMPL(int16_t, INT16);
...@@ -172,14 +173,13 @@ GET_DATA_IMPL(double, FP64); ...@@ -172,14 +173,13 @@ GET_DATA_IMPL(double, FP64);
CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \ CHECK(GetDataType() == VarDescAPI::VarDataType::type__) \
<< "Data Type mismatch, call SetDataType first."; \ << "Data Type mismatch, call SetDataType first."; \
auto* data_builder = \ auto* data_builder = \
desc_->GetMutableField<ListBuilder<CharBuilder>>("data"); \ desc_->GetMutableField<PrimaryListBuilder<char>>("data"); \
CHECK(data_builder); \ CHECK(data_builder); \
data_builder->Clear(); \ data_builder->Clear(); \
size_t size = size__ * sizeof(T); \ size_t size = size__ * sizeof(T); \
auto* data_ptr = reinterpret_cast<const char*>(data_ptr__); \ auto* data_ptr = reinterpret_cast<const char*>(data_ptr__); \
for (size_t i = 0; i < size; ++i) { \ std::vector<char> data_vec(data_ptr, data_ptr + size); \
data_builder->New()->set(data_ptr[i]); \ data_builder->set(data_vec);
}
#define SET_DATA_IMPL(T, type__) \ #define SET_DATA_IMPL(T, type__) \
template <> \ template <> \
......
...@@ -191,7 +191,7 @@ class ParamDesc : public StructBuilder { ...@@ -191,7 +191,7 @@ class ParamDesc : public StructBuilder {
New<lod_type>("lod"); New<lod_type>("lod");
NewUInt32("tensor_version"); NewUInt32("tensor_version");
New<TensorDesc>("tensor_desc"); New<TensorDesc>("tensor_desc");
New<ListBuilder<CharBuilder>>("data"); New<PrimaryListBuilder<char>>("data");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册