未验证 提交 ec6e0a2c 编写于 作者: H Hui Zhang 提交者: GitHub

jit layer: optimize model param memory usage (#50135)

* jit layer support multi thread
上级 6edc7bba
...@@ -25,9 +25,10 @@ ...@@ -25,9 +25,10 @@
namespace paddle { namespace paddle {
namespace jit { namespace jit {
InterpreterEngine::InterpreterEngine(const std::shared_ptr<FunctionInfo> &info, InterpreterEngine::InterpreterEngine(
const VariableMap &params_dict, const std::shared_ptr<FunctionInfo> &info,
const phi::Place &place) const std::shared_ptr<VariableMap> &params_dict,
const phi::Place &place)
: info_(info), params_dict_(params_dict), place_(place) { : info_(info), params_dict_(params_dict), place_(place) {
info_->RemoveDescFeedFetch(); info_->RemoveDescFeedFetch();
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
......
...@@ -36,7 +36,7 @@ using InterpreterCore = framework::InterpreterCore; ...@@ -36,7 +36,7 @@ using InterpreterCore = framework::InterpreterCore;
class InterpreterEngine : public BaseEngine { class InterpreterEngine : public BaseEngine {
public: public:
InterpreterEngine(const std::shared_ptr<FunctionInfo> &info, InterpreterEngine(const std::shared_ptr<FunctionInfo> &info,
const VariableMap &params_dict, const std::shared_ptr<VariableMap> &params_dict,
const phi::Place &place); const phi::Place &place);
~InterpreterEngine() noexcept {} ~InterpreterEngine() noexcept {}
...@@ -54,7 +54,7 @@ class InterpreterEngine : public BaseEngine { ...@@ -54,7 +54,7 @@ class InterpreterEngine : public BaseEngine {
private: private:
std::shared_ptr<FunctionInfo> info_; std::shared_ptr<FunctionInfo> info_;
VariableMap params_dict_; std::shared_ptr<VariableMap> params_dict_;
framework::Scope scope_; framework::Scope scope_;
phi::Place place_; phi::Place place_;
std::shared_ptr<framework::InterpreterCore> inner_interpreter_; std::shared_ptr<framework::InterpreterCore> inner_interpreter_;
......
...@@ -27,11 +27,15 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt, ...@@ -27,11 +27,15 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
DenseTensor *t, DenseTensor *t,
const platform::Place &place); const platform::Place &place);
PredictorEngine::PredictorEngine(const std::shared_ptr<FunctionInfo> &info, PredictorEngine::PredictorEngine(
const VariableMap &params_dict, const std::shared_ptr<FunctionInfo> &info,
const phi::Place &place) const std::shared_ptr<VariableMap> &params_dict,
: info_(info), scope_(new framework::Scope()), place_(place) { const phi::Place &place)
utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get()); : info_(info),
params_dict_(params_dict),
scope_(new framework::Scope()),
place_(place) {
utils::ShareParamsIntoScope(info_->ParamNames(), params_dict_, scope_.get());
VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get()); VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get());
// TODO(Aurelius84): Expose AnalysisConfig to user. // TODO(Aurelius84): Expose AnalysisConfig to user.
...@@ -66,6 +70,12 @@ PredictorEngine::PredictorEngine( ...@@ -66,6 +70,12 @@ PredictorEngine::PredictorEngine(
predictor_(std::dynamic_pointer_cast<AnalysisPredictor, PaddlePredictor>( predictor_(std::dynamic_pointer_cast<AnalysisPredictor, PaddlePredictor>(
predictor)) {} predictor)) {}
std::unique_ptr<BaseEngine> PredictorEngine::Clone(void *stream) {
auto *x = new PredictorEngine(
info_, scope_, place_, std::move(predictor_->Clone(stream)));
return std::unique_ptr<BaseEngine>(x);
}
std::vector<Tensor> PredictorEngine::operator()( std::vector<Tensor> PredictorEngine::operator()(
const std::vector<Tensor> &inputs) { const std::vector<Tensor> &inputs) {
auto dense_tensors = utils::ToDenseTensors(inputs); auto dense_tensors = utils::ToDenseTensors(inputs);
...@@ -199,11 +209,5 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt, ...@@ -199,11 +209,5 @@ static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
return true; return true;
} }
std::unique_ptr<BaseEngine> PredictorEngine::Clone(void *stream) {
auto *x = new PredictorEngine(
info_, scope_, place_, std::move(predictor_->Clone(stream)));
return std::unique_ptr<BaseEngine>(x);
}
} // namespace jit } // namespace jit
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,7 @@ namespace jit { ...@@ -31,7 +31,7 @@ namespace jit {
class PredictorEngine : public BaseEngine { class PredictorEngine : public BaseEngine {
public: public:
PredictorEngine(const std::shared_ptr<FunctionInfo> &info, PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
const VariableMap &params_dict, const std::shared_ptr<VariableMap> &params_dict,
const phi::Place &place); const phi::Place &place);
PredictorEngine(const std::shared_ptr<FunctionInfo> &info, PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
...@@ -50,6 +50,7 @@ class PredictorEngine : public BaseEngine { ...@@ -50,6 +50,7 @@ class PredictorEngine : public BaseEngine {
private: private:
std::shared_ptr<FunctionInfo> info_; std::shared_ptr<FunctionInfo> info_;
std::shared_ptr<VariableMap> params_dict_;
std::shared_ptr<framework::Scope> scope_; std::shared_ptr<framework::Scope> scope_;
phi::Place place_; phi::Place place_;
std::shared_ptr<AnalysisPredictor> predictor_; std::shared_ptr<AnalysisPredictor> predictor_;
......
...@@ -71,18 +71,18 @@ void ShareIntoScope(const std::vector<std::string> &ordered_input_names, ...@@ -71,18 +71,18 @@ void ShareIntoScope(const std::vector<std::string> &ordered_input_names,
} }
void ShareParamsIntoScope(const std::vector<std::string> &param_names, void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const VariableMap &params_dict, const std::shared_ptr<VariableMap> &params_dict,
framework::Scope *scope) { framework::Scope *scope) {
for (size_t i = 0; i < param_names.size(); ++i) { for (size_t i = 0; i < param_names.size(); ++i) {
std::string name = param_names[i]; std::string name = param_names[i];
PADDLE_ENFORCE_EQ(params_dict.count(name), PADDLE_ENFORCE_EQ(params_dict->count(name),
1, 1,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Parameter named %s is not existed in params_dict. " "Parameter named %s is not existed in params_dict. "
"Please check that your model was saved correctly", "Please check that your model was saved correctly",
name)); name));
auto &param = params_dict.find(name)->second; auto &param = params_dict->find(name)->second;
auto &dense_tensor = param->Get<DenseTensor>(); auto &dense_tensor = param->Get<DenseTensor>();
auto *var = scope->Var(name); auto *var = scope->Var(name);
auto *dst_tensor = var->GetMutable<DenseTensor>(); auto *dst_tensor = var->GetMutable<DenseTensor>();
......
...@@ -51,14 +51,14 @@ void ShareIntoScope(const std::vector<std::string> &ordered_input_names, ...@@ -51,14 +51,14 @@ void ShareIntoScope(const std::vector<std::string> &ordered_input_names,
framework::Scope *scope); framework::Scope *scope);
void ShareParamsIntoScope(const std::vector<std::string> &param_names, void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const VariableMap &params_dict, const std::shared_ptr<VariableMap> &params_dict,
framework::Scope *scope); framework::Scope *scope);
void RemoveFeedFetch(framework::ProgramDesc *program_desc); void RemoveFeedFetch(framework::ProgramDesc *program_desc);
template <typename T> template <typename T>
std::shared_ptr<T> MakeEngine(const std::shared_ptr<FunctionInfo> &info, std::shared_ptr<T> MakeEngine(const std::shared_ptr<FunctionInfo> &info,
const VariableMap &params_dict, const std::shared_ptr<VariableMap> &params_dict,
const phi::Place &place) { const phi::Place &place) {
return std::make_shared<T>(info, params_dict, place); return std::make_shared<T>(info, params_dict, place);
} }
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
namespace paddle { namespace paddle {
namespace jit { namespace jit {
Layer::Layer(const VariableMap& params_map, Layer::Layer(const std::shared_ptr<VariableMap>& params_map,
const VariableMap& attrs_map, const std::shared_ptr<VariableMap>& attrs_map,
const FunctionInfoMap& info_map, const FunctionInfoMap& info_map,
const phi::Place& place) const phi::Place& place)
: params_map_(params_map), : params_map_(params_map),
...@@ -80,12 +80,12 @@ std::vector<std::string> Layer::FunctionNames() const { ...@@ -80,12 +80,12 @@ std::vector<std::string> Layer::FunctionNames() const {
#define PD_SPECIALZE_ATTRIBUTE_TYPE(T) \ #define PD_SPECIALZE_ATTRIBUTE_TYPE(T) \
template <> \ template <> \
T Layer::Attribute<T>(const std::string& name) const { \ T Layer::Attribute<T>(const std::string& name) const { \
if (attrs_map_.find(name) == attrs_map_.end()) { \ if (attrs_map_->find(name) == attrs_map_->end()) { \
PADDLE_THROW(phi::errors::NotFound( \ PADDLE_THROW(phi::errors::NotFound( \
"Attribute can not found %s, please check if it exists.")); \ "Attribute can not found %s, please check if it exists.")); \
return T(); \ return T(); \
} \ } \
auto var = attrs_map_.at(name); \ auto var = attrs_map_->at(name); \
T ret = var->Get<T>(); \ T ret = var->Get<T>(); \
return ret; \ return ret; \
} }
......
...@@ -43,8 +43,8 @@ using FunctionInfoMap = ...@@ -43,8 +43,8 @@ using FunctionInfoMap =
class Layer { class Layer {
public: public:
Layer(const VariableMap& params_map, Layer(const std::shared_ptr<VariableMap>& params_map,
const VariableMap& attrs_map_, const std::shared_ptr<VariableMap>& attrs_map_,
const FunctionInfoMap& info_map, const FunctionInfoMap& info_map,
const phi::Place& place); const phi::Place& place);
...@@ -70,8 +70,8 @@ class Layer { ...@@ -70,8 +70,8 @@ class Layer {
std::shared_ptr<Layer> Clone(void* stream = nullptr); std::shared_ptr<Layer> Clone(void* stream = nullptr);
private: private:
VariableMap params_map_; std::shared_ptr<VariableMap> params_map_;
VariableMap attrs_map_; std::shared_ptr<VariableMap> attrs_map_;
FunctionInfoMap info_map_; FunctionInfoMap info_map_;
phi::Place place_; phi::Place place_;
std::shared_ptr<CompilationUnit> unit_; std::shared_ptr<CompilationUnit> unit_;
......
...@@ -58,12 +58,12 @@ Layer Deserializer::operator()(const std::string& path, ...@@ -58,12 +58,12 @@ Layer Deserializer::operator()(const std::string& path,
info_map[func_name]->SetProgramFilePath(it.second); info_map[func_name]->SetProgramFilePath(it.second);
} }
VariableMap params_dict; auto params_dict = std::make_shared<VariableMap>();
VariableMap attrs_dict; auto attrs_dict = std::make_shared<VariableMap>();
ReadTensorData(path + PDPARAMS_SUFFIX, param_names_set, place, &params_dict); ReadTensorData(path + PDPARAMS_SUFFIX, param_names_set, place, params_dict);
if (utils::FileExists(path + PROPERTY_SUFFIX)) { if (utils::FileExists(path + PROPERTY_SUFFIX)) {
ReadAttributeData(path + PROPERTY_SUFFIX, &attrs_dict); ReadAttributeData(path + PROPERTY_SUFFIX, attrs_dict);
VLOG(3) << "Read Property Success!"; VLOG(3) << "Read Property Success!";
} }
...@@ -90,10 +90,11 @@ Layer Deserializer::operator()(const std::string& path, ...@@ -90,10 +90,11 @@ Layer Deserializer::operator()(const std::string& path,
return layer; return layer;
} }
void Deserializer::ReadTensorData(const std::string& file_name, void Deserializer::ReadTensorData(
const std::set<std::string>& var_name, const std::string& file_name,
const phi::Place& place, const std::set<std::string>& var_name,
VariableMap* params_dict) const { const phi::Place& place,
std::shared_ptr<VariableMap> params_dict) const {
VLOG(3) << "ReadTensorData from: " << file_name; VLOG(3) << "ReadTensorData from: " << file_name;
std::ifstream fin(file_name, std::ios::binary); std::ifstream fin(file_name, std::ios::binary);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
...@@ -108,12 +109,15 @@ void Deserializer::ReadTensorData(const std::string& file_name, ...@@ -108,12 +109,15 @@ void Deserializer::ReadTensorData(const std::string& file_name,
} }
} }
void Deserializer::ReadAttributeData(const std::string& file_path, void Deserializer::ReadAttributeData(
VariableMap* attrs_dict) const { const std::string& file_path,
std::shared_ptr<VariableMap> attrs_dict) const {
VLOG(3) << "ReadPropertyData from: " << file_path; VLOG(3) << "ReadPropertyData from: " << file_path;
Property p; Property p;
p.Deserialization(file_path); p.Deserialization(file_path);
*attrs_dict = static_cast<VariableMap>(p.Values()); for (auto& it : p.Values()) {
attrs_dict->emplace(it.first, it.second);
}
return; return;
} }
......
...@@ -55,11 +55,11 @@ class Deserializer { ...@@ -55,11 +55,11 @@ class Deserializer {
void ReadTensorData(const std::string& file_name, void ReadTensorData(const std::string& file_name,
const std::set<std::string>& var_name, const std::set<std::string>& var_name,
const phi::Place& place, const phi::Place& place,
VariableMap* params_dict) const; std::shared_ptr<VariableMap> params_dict) const;
// property pb // property pb
void ReadAttributeData(const std::string& file_path, void ReadAttributeData(const std::string& file_path,
VariableMap* attrs_dict) const; std::shared_ptr<VariableMap> attrs_dict) const;
// void ReadExtraInfo(const std::string& file_name) const; // void ReadExtraInfo(const std::string& file_name) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册