提交 743cb840 编写于 作者: T Tao Luo

update with comments

test=develop
上级 42359e88
...@@ -97,7 +97,7 @@ void ExecutorThreadWorker::SetDevice() { ...@@ -97,7 +97,7 @@ void ExecutorThreadWorker::SetDevice() {
static unsigned concurrency_cap = std::thread::hardware_concurrency(); static unsigned concurrency_cap = std::thread::hardware_concurrency();
int thread_id = this->thread_id_; int thread_id = this->thread_id_;
if ((unsigned)thread_id < concurrency_cap) { if (static_cast<unsigned>(thread_id) < concurrency_cap) {
unsigned proc = thread_id; unsigned proc = thread_id;
cpu_set_t mask; cpu_set_t mask;
......
...@@ -103,7 +103,7 @@ struct Argument { ...@@ -103,7 +103,7 @@ struct Argument {
// Model specified with program and parameters files. // Model specified with program and parameters files.
DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string); DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string); DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(is_memory_load, IsMemoryLoad, bool); DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
// The overall graph to work on. // The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
......
...@@ -46,7 +46,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { ...@@ -46,7 +46,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
argument->model_params_path_valid()) { argument->model_params_path_valid()) {
auto program = auto program =
LoadModel(argument->model_program_path(), argument->model_params_path(), LoadModel(argument->model_program_path(), argument->model_params_path(),
argument->scope_ptr(), place, argument->is_memory_load()); argument->scope_ptr(), place, argument->model_from_memory());
argument->SetMainProgram(program.release()); argument->SetMainProgram(program.release());
} else { } else {
PADDLE_THROW( PADDLE_THROW(
...@@ -69,9 +69,13 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel( ...@@ -69,9 +69,13 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel( std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
const std::string &program_path, const std::string &params_path, const std::string &program_path, const std::string &params_path,
framework::Scope *scope, const platform::Place &place, framework::Scope *scope, const platform::Place &place,
bool is_memory_load) { bool model_from_memory) {
framework::Executor exe(place); framework::Executor exe(place);
return Load(&exe, scope, program_path, params_path, is_memory_load); if (!model_from_memory) {
return Load(&exe, scope, program_path, params_path);
} else {
return LoadFromMemory(&exe, scope, program_path, params_path);
}
} }
std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; } std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; }
......
...@@ -39,7 +39,7 @@ class IrGraphBuildPass : public AnalysisPass { ...@@ -39,7 +39,7 @@ class IrGraphBuildPass : public AnalysisPass {
std::unique_ptr<framework::ProgramDesc> LoadModel( std::unique_ptr<framework::ProgramDesc> LoadModel(
const std::string &program_path, const std::string &params_path, const std::string &program_path, const std::string &params_path,
framework::Scope *scope, const platform::Place &place, framework::Scope *scope, const platform::Place &place,
bool is_memory_load); bool model_from_memory);
std::string model_binary_str_; std::string model_binary_str_;
}; };
......
...@@ -53,7 +53,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) { ...@@ -53,7 +53,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
use_tensorrt_ = other.use_tensorrt_; use_tensorrt_ = other.use_tensorrt_;
tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
tensorrt_workspace_size_ = other.tensorrt_workspace_size_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
is_memory_load_ = other.is_memory_load_; model_from_memory_ = other.model_from_memory_;
if (use_gpu) { if (use_gpu) {
pass_builder_.reset(new GpuPassStrategy( pass_builder_.reset(new GpuPassStrategy(
...@@ -81,7 +81,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) { ...@@ -81,7 +81,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
use_tensorrt_ = other.use_tensorrt_; use_tensorrt_ = other.use_tensorrt_;
tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
tensorrt_workspace_size_ = other.tensorrt_workspace_size_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
is_memory_load_ = other.is_memory_load_; model_from_memory_ = other.model_from_memory_;
pass_builder_ = std::move(other.pass_builder_); pass_builder_ = std::move(other.pass_builder_);
} }
...@@ -105,12 +105,13 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size, ...@@ -105,12 +105,13 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
pass_builder()->InsertPass(1, "tensorrt_subgraph_pass"); pass_builder()->InsertPass(1, "tensorrt_subgraph_pass");
} }
void contrib::AnalysisConfig::SetProgBufferAndParamBuffer( void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
const char *prog_buffer, size_t prog_buffer_size, const char *param_buffer, size_t prog_buffer_size,
size_t param_buffer_size) { const char *param_buffer,
size_t param_buffer_size) {
prog_file = std::string(prog_buffer, prog_buffer + prog_buffer_size); prog_file = std::string(prog_buffer, prog_buffer + prog_buffer_size);
param_file = std::string(param_buffer, param_buffer + param_buffer_size); param_file = std::string(param_buffer, param_buffer + param_buffer_size);
is_memory_load_ = true; model_from_memory_ = true;
} }
} // namespace paddle } // namespace paddle
...@@ -308,7 +308,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -308,7 +308,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.SetUseGPU(config_.use_gpu); argument_.SetUseGPU(config_.use_gpu);
argument_.SetGPUDeviceId(config_.device); argument_.SetGPUDeviceId(config_.device);
argument_.SetIsMemoryLoad(config_.is_memory_load_); argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program // Analyze inference_program
if (!config_.model_dir.empty()) { if (!config_.model_dir.empty()) {
argument_.SetModelDir(config_.model_dir); argument_.SetModelDir(config_.model_dir);
...@@ -451,11 +451,12 @@ bool AnalysisPredictor::LoadProgramDesc() { ...@@ -451,11 +451,12 @@ bool AnalysisPredictor::LoadProgramDesc() {
// Create ProgramDesc // Create ProgramDesc
framework::proto::ProgramDesc proto; framework::proto::ProgramDesc proto;
if (!config_.is_memory_load()) { if (!config_.model_from_memory()) {
std::string pb_content; std::string pb_content;
// Read binary // Read binary
std::ifstream fin(filename, std::ios::in | std::ios::binary); std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename); PADDLE_ENFORCE(static_cast<bool>(fin.is_open()), "Cannot open file %s",
filename);
fin.seekg(0, std::ios::end); fin.seekg(0, std::ios::end);
pb_content.resize(fin.tellg()); pb_content.resize(fin.tellg());
fin.seekg(0, std::ios::beg); fin.seekg(0, std::ios::beg);
......
...@@ -55,11 +55,9 @@ struct AnalysisConfig : public NativeConfig { ...@@ -55,11 +55,9 @@ struct AnalysisConfig : public NativeConfig {
bool use_mkldnn() const { return use_mkldnn_; } bool use_mkldnn() const { return use_mkldnn_; }
// Specify the memory buffer of program and parameter // Specify the memory buffer of program and parameter
void SetProgBufferAndParamBuffer(const char* prog_buffer, void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
size_t prog_buffer_size, const char* program_buffer, size_t program_buffer_size);
const char* program_buffer, bool model_from_memory() const { return model_from_memory_; }
size_t program_buffer_size);
bool is_memory_load() const { return is_memory_load_; }
friend class ::paddle::AnalysisPredictor; friend class ::paddle::AnalysisPredictor;
...@@ -69,7 +67,7 @@ struct AnalysisConfig : public NativeConfig { ...@@ -69,7 +67,7 @@ struct AnalysisConfig : public NativeConfig {
int tensorrt_workspace_size_; int tensorrt_workspace_size_;
int tensorrt_max_batchsize_; int tensorrt_max_batchsize_;
std::unique_ptr<PassStrategy> pass_builder_; std::unique_ptr<PassStrategy> pass_builder_;
bool is_memory_load_{false}; bool model_from_memory_{false};
}; };
// Configurations for Anakin engine. // Configurations for Anakin engine.
......
...@@ -70,7 +70,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope, ...@@ -70,7 +70,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program, const framework::ProgramDesc& main_program,
const std::string& dirname, const std::string& dirname,
const std::string& param_filename, const std::string& param_filename,
bool is_memory_load = false) { bool model_from_memory = false) {
const framework::BlockDesc& global_block = main_program.Block(0); const framework::BlockDesc& global_block = main_program.Block(0);
framework::ProgramDesc* load_program = new framework::ProgramDesc(); framework::ProgramDesc* load_program = new framework::ProgramDesc();
...@@ -109,7 +109,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope, ...@@ -109,7 +109,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
op->SetType("load_combine"); op->SetType("load_combine");
op->SetOutput("Out", paramlist); op->SetOutput("Out", paramlist);
op->SetAttr("file_path", {param_filename}); op->SetAttr("file_path", {param_filename});
op->SetAttr("is_memory_load", {is_memory_load}); op->SetAttr("model_from_memory", {model_from_memory});
op->CheckAttrs(); op->CheckAttrs();
} }
...@@ -132,23 +132,17 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -132,23 +132,17 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
"model version %ld is not supported.", "model version %ld is not supported.",
main_program->Version()); main_program->Version());
// is_memory_load is false in separate parameters. // model_from_memory is false in separate parameters.
LoadPersistables(executor, scope, *main_program, dirname, "", LoadPersistables(executor, scope, *main_program, dirname, "",
false /* is_memory_load */); false /* model_from_memory */);
return main_program; return main_program;
} }
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, std::unique_ptr<framework::ProgramDesc> Load(
framework::Scope* scope, framework::Executor* executor, framework::Scope* scope,
const std::string& prog_filename, const std::string& prog_filename, const std::string& param_filename) {
const std::string& param_filename,
bool is_memory_load = false) {
std::string program_desc_str; std::string program_desc_str;
if (!is_memory_load) { ReadBinaryFile(prog_filename, &program_desc_str);
ReadBinaryFile(prog_filename, &program_desc_str);
} else {
program_desc_str = prog_filename;
}
std::unique_ptr<framework::ProgramDesc> main_program( std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str)); new framework::ProgramDesc(program_desc_str));
...@@ -157,15 +151,22 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -157,15 +151,22 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
main_program->Version()); main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_filename, LoadPersistables(executor, scope, *main_program, "", param_filename,
is_memory_load); false /* model_from_memory */);
return main_program; return main_program;
} }
std::unique_ptr<framework::ProgramDesc> Load( std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
framework::Executor* executor, framework::Scope* scope, framework::Executor* executor, framework::Scope* scope,
const std::string& prog_filename, const std::string& param_filename) { const std::string& prog_buffer, const std::string& param_buffer) {
return Load(executor, scope, prog_filename, param_filename, std::unique_ptr<framework::ProgramDesc> main_program(
false /* is_memory_load */); new framework::ProgramDesc(prog_buffer));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_buffer,
true /* model_from_memory */);
return main_program;
} }
void SaveVars(const framework::Scope& scope, void SaveVars(const framework::Scope& scope,
......
...@@ -30,7 +30,8 @@ void Init(const std::vector<std::string> argv); ...@@ -30,7 +30,8 @@ void Init(const std::vector<std::string> argv);
void LoadPersistables(framework::Executor* executor, framework::Scope* scope, void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program, const framework::ProgramDesc& main_program,
const std::string& dirname, const std::string& dirname,
const std::string& param_filename, bool is_memory_load); const std::string& param_filename,
bool model_from_memory);
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
framework::Scope* scope, framework::Scope* scope,
...@@ -41,11 +42,9 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -41,11 +42,9 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
const std::string& prog_filename, const std::string& prog_filename,
const std::string& param_filename); const std::string& param_filename);
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
framework::Scope* scope, framework::Executor* executor, framework::Scope* scope,
const std::string& prog_filename, const std::string& prog_buffer, const std::string& param_buffer);
const std::string& param_filename,
bool is_memory_load);
// Save the variables from a scope to disk. // Save the variables from a scope to disk.
void SaveVars(const framework::Scope& scope, void SaveVars(const framework::Scope& scope,
......
...@@ -98,8 +98,8 @@ void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) { ...@@ -98,8 +98,8 @@ void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) {
std::string buffer_prog, buffer_param; std::string buffer_prog, buffer_param;
ReadBinaryFile(FLAGS_infer_model + "/__model__", &buffer_prog); ReadBinaryFile(FLAGS_infer_model + "/__model__", &buffer_prog);
ReadBinaryFile(FLAGS_infer_model + "/param", &buffer_param); ReadBinaryFile(FLAGS_infer_model + "/param", &buffer_param);
cfg->SetProgBufferAndParamBuffer(&buffer_prog[0], buffer_prog.size(), cfg->SetModelBuffer(&buffer_prog[0], buffer_prog.size(), &buffer_param[0],
&buffer_param[0], buffer_param.size()); buffer_param.size());
} else { } else {
cfg->prog_file = FLAGS_infer_model + "/__model__"; cfg->prog_file = FLAGS_infer_model + "/__model__";
cfg->param_file = FLAGS_infer_model + "/param"; cfg->param_file = FLAGS_infer_model + "/param";
......
...@@ -63,7 +63,7 @@ std::ostream &operator<<(std::ostream &os, ...@@ -63,7 +63,7 @@ std::ostream &operator<<(std::ostream &os,
os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n"; os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n";
num_spaces++; num_spaces++;
os << *reinterpret_cast<const NativeConfig *>(&config); os << *reinterpret_cast<const NativeConfig *>(&config);
if (!config.is_memory_load()) { if (!config.model_from_memory()) {
os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n"; os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n"; os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
} else { } else {
......
...@@ -32,12 +32,12 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -32,12 +32,12 @@ class LoadCombineOp : public framework::OperatorBase {
const platform::Place &place) const override { const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
auto load_as_fp16 = Attr<bool>("load_as_fp16"); auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto is_memory_load = Attr<bool>("is_memory_load"); auto model_from_memory = Attr<bool>("model_from_memory");
auto out_var_names = Outputs("Out"); auto out_var_names = Outputs("Out");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0, static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0."); "The number of output variables should be greater than 0.");
if (!is_memory_load) { if (!model_from_memory) {
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename); "Cannot open file %s for load_combine op", filename);
...@@ -112,7 +112,7 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -112,7 +112,7 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"LoDTensors will be loaded from \"file_path\".") "LoDTensors will be loaded from \"file_path\".")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddAttr<bool>("is_memory_load", AddAttr<bool>("model_from_memory",
"(boolean, default false)" "(boolean, default false)"
"If true, file_path is in memory, and LoDTensors will be " "If true, file_path is in memory, and LoDTensors will be "
"loaded directly from memory") "loaded directly from memory")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册