From 405b2486db8bbbafad58e4362ef2dc0be0f74c7e Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Tue, 4 Dec 2018 18:59:49 +0800 Subject: [PATCH] support loading from memory test=develop --- .../fluid/framework/executor_thread_worker.cc | 2 +- paddle/fluid/inference/analysis/argument.h | 1 + .../analysis/passes/ir_graph_build_pass.cc | 7 ++- .../analysis/passes/ir_graph_build_pass.h | 5 +- paddle/fluid/inference/api/analysis_config.cc | 11 ++++ .../fluid/inference/api/analysis_predictor.cc | 38 ++++++------ .../inference/api/paddle_analysis_config.h | 10 ++- paddle/fluid/inference/io.cc | 34 ++++++++--- paddle/fluid/inference/io.h | 8 ++- .../tests/api/analyzer_ner_tester.cc | 24 ++++++-- .../inference/tests/api/config_printer.h | 9 ++- paddle/fluid/operators/impl/CMakeLists.txt | 1 + paddle/fluid/operators/impl/load_combine.cc | 61 +++++++++++++++++++ paddle/fluid/operators/impl/load_combine.h | 33 ++++++++++ paddle/fluid/operators/load_combine_op.cc | 37 +++++++---- 15 files changed, 229 insertions(+), 52 deletions(-) create mode 100644 paddle/fluid/operators/impl/CMakeLists.txt create mode 100644 paddle/fluid/operators/impl/load_combine.cc create mode 100644 paddle/fluid/operators/impl/load_combine.h diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc index 4e4001e97..e3849931d 100644 --- a/paddle/fluid/framework/executor_thread_worker.cc +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -97,7 +97,7 @@ void ExecutorThreadWorker::SetDevice() { static unsigned concurrency_cap = std::thread::hardware_concurrency(); int thread_id = this->thread_id_; - if (thread_id < concurrency_cap) { + if ((unsigned)thread_id < concurrency_cap) { unsigned proc = thread_id; cpu_set_t mask; diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 21203e2d9..d205465ab 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -103,6 +103,7 @@ struct Argument { // Model specified with program and parameters files. DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string); DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string); + DECL_ARGUMENT_FIELD(is_memory_load, IsMemoryLoad, bool); // The overall graph to work on. DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index 740030c3a..31138d734 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -46,7 +46,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { argument->model_params_path_valid()) { auto program = LoadModel(argument->model_program_path(), argument->model_params_path(), - argument->scope_ptr(), place); + argument->scope_ptr(), place, argument->is_memory_load()); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( @@ -68,9 +68,10 @@ std::unique_ptr IrGraphBuildPass::LoadModel( std::unique_ptr IrGraphBuildPass::LoadModel( const std::string &program_path, const std::string ¶ms_path, - framework::Scope *scope, const platform::Place &place) { + framework::Scope *scope, const platform::Place &place, + bool is_memory_load) { framework::Executor exe(place); - return Load(&exe, scope, program_path, params_path); + return Load(&exe, scope, program_path, params_path, is_memory_load); } std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; } diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h index 271e64fce..ae7de676d 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h @@ -24,7 +24,7 @@ namespace inference { namespace analysis { /* - * Load program and parameter to memory from the disk. + * Load program and parameter to memory from the disk or directly from memory. */ class IrGraphBuildPass : public AnalysisPass { public: @@ -38,7 +38,8 @@ class IrGraphBuildPass : public AnalysisPass { const platform::Place &place); std::unique_ptr LoadModel( const std::string &program_path, const std::string ¶ms_path, - framework::Scope *scope, const platform::Place &place); + framework::Scope *scope, const platform::Place &place, + bool is_memory_load); std::string model_binary_str_; }; diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index dd75f0d9a..bfa1f1908 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -53,6 +53,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) { use_tensorrt_ = other.use_tensorrt_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_; + is_memory_load_ = other.is_memory_load_; if (use_gpu) { pass_builder_.reset(new GpuPassStrategy( @@ -80,6 +81,8 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) { use_tensorrt_ = other.use_tensorrt_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_; + is_memory_load_ = other.is_memory_load_; + pass_builder_ = std::move(other.pass_builder_); } @@ -102,4 +105,12 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size, pass_builder()->InsertPass(1, "tensorrt_subgraph_pass"); } +void contrib::AnalysisConfig::SetProgBufferAndParamBuffer( + const char *prog_buffer, size_t prog_buffer_size, const char *param_buffer, + size_t param_buffer_size) { + prog_file = std::string(prog_buffer, prog_buffer + prog_buffer_size); + param_file = std::string(param_buffer, param_buffer + param_buffer_size); + is_memory_load_ = true; +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 391330a7c..6091c5a62 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -304,20 +304,20 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, // NOTE All the members in AnalysisConfig should be copied to Argument. void AnalysisPredictor::OptimizeInferenceProgram() { + LOG(INFO) << "optimization program"; status_program_optimized_ = true; argument_.SetUseGPU(config_.use_gpu); argument_.SetGPUDeviceId(config_.device); + argument_.SetIsMemoryLoad(config_.is_memory_load_); // Analyze inference_program if (!config_.model_dir.empty()) { argument_.SetModelDir(config_.model_dir); - } else { - PADDLE_ENFORCE( - !config_.param_file.empty(), - "Either model_dir or (param_file, prog_file) should be set."); - PADDLE_ENFORCE(!config_.prog_file.empty()); + } else if (!config_.param_file.empty() && !config_.prog_file.empty()) { argument_.SetModelProgramPath(config_.prog_file); argument_.SetModelParamsPath(config_.param_file); + } else { + PADDLE_THROW("Either model_dir or (param_file, prog_file) should be set."); } if (config_.use_gpu && config_.use_tensorrt_) { @@ -448,20 +448,23 @@ bool AnalysisPredictor::LoadProgramDesc() { return false; } - std::string pb_content; - // Read binary - std::ifstream fin(filename, std::ios::in | std::ios::binary); - PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", filename); - fin.seekg(0, std::ios::end); - - pb_content.resize(fin.tellg()); - fin.seekg(0, std::ios::beg); - fin.read(&(pb_content.at(0)), pb_content.size()); - fin.close(); - // Create ProgramDesc framework::proto::ProgramDesc proto; - proto.ParseFromString(pb_content); + if (!config_.is_memory_load()) { + std::string pb_content; + // Read binary + std::ifstream fin(filename, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", filename); + fin.seekg(0, std::ios::end); + pb_content.resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&(pb_content.at(0)), pb_content.size()); + fin.close(); + + proto.ParseFromString(pb_content); + } else { + proto.ParseFromString(config_.prog_file); + } inference_program_.reset(new framework::ProgramDesc(proto)); return true; } @@ -469,6 +472,7 @@ bool AnalysisPredictor::LoadProgramDesc() { bool AnalysisPredictor::LoadParameters() { PADDLE_ENFORCE_NOT_NULL(inference_program_.get(), "The inference program should be loaded first."); + const auto &global_block = inference_program_->MutableBlock(0); // create a temporary program to load parameters. diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index a09bd1cac..4e9d5bc32 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -52,10 +52,15 @@ struct AnalysisConfig : public NativeConfig { bool use_tensorrt() const { return use_tensorrt_; } void EnableMKLDNN(); - // NOTE this is just for internal development, please not use it. - // NOT stable yet. bool use_mkldnn() const { return use_mkldnn_; } + // Specify the memory buffer of program and parameter + void SetProgBufferAndParamBuffer(const char* prog_buffer, + size_t prog_buffer_size, + const char* program_buffer, + size_t program_buffer_size); + bool is_memory_load() const { return is_memory_load_; } + friend class ::paddle::AnalysisPredictor; protected: @@ -64,6 +69,7 @@ struct AnalysisConfig : public NativeConfig { int tensorrt_workspace_size_; int tensorrt_max_batchsize_; std::unique_ptr pass_builder_; + bool is_memory_load_{false}; }; // Configurations for Anakin engine. diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 31f43bfdc..053c516f8 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/version.h" +#include "paddle/fluid/operators/impl/load_combine.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/pybind/pybind.h" @@ -69,7 +70,8 @@ bool IsPersistable(const framework::VarDesc* var) { void LoadPersistables(framework::Executor* executor, framework::Scope* scope, const framework::ProgramDesc& main_program, const std::string& dirname, - const std::string& param_filename) { + const std::string& param_filename, + bool is_memory_load = false) { const framework::BlockDesc& global_block = main_program.Block(0); framework::ProgramDesc* load_program = new framework::ProgramDesc(); @@ -108,6 +110,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope, op->SetType("load_combine"); op->SetOutput("Out", paramlist); op->SetAttr("file_path", {param_filename}); + op->SetAttr("is_memory_load", {is_memory_load}); op->CheckAttrs(); } @@ -130,16 +133,23 @@ std::unique_ptr Load(framework::Executor* executor, "model version %ld is not supported.", main_program->Version()); - LoadPersistables(executor, scope, *main_program, dirname, ""); + // is_memory_load is false in seperate parameters. + LoadPersistables(executor, scope, *main_program, dirname, "", + false /* is_memory_load */); return main_program; } -std::unique_ptr Load( - framework::Executor* executor, framework::Scope* scope, - const std::string& prog_filename, const std::string& param_filename) { - std::string model_filename = prog_filename; +std::unique_ptr Load(framework::Executor* executor, + framework::Scope* scope, + const std::string& prog_filename, + const std::string& param_filename, + bool is_memory_load = false) { std::string program_desc_str; - ReadBinaryFile(model_filename, &program_desc_str); + if (!is_memory_load) { + ReadBinaryFile(prog_filename, &program_desc_str); + } else { + program_desc_str = prog_filename; + } std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); @@ -147,10 +157,18 @@ std::unique_ptr Load( "model version %ld is not supported.", main_program->Version()); - LoadPersistables(executor, scope, *main_program, "", param_filename); + LoadPersistables(executor, scope, *main_program, "", param_filename, + is_memory_load); return main_program; } +std::unique_ptr Load( + framework::Executor* executor, framework::Scope* scope, + const std::string& prog_filename, const std::string& param_filename) { + return Load(executor, scope, prog_filename, param_filename, + false /* is_memory_load */); +} + void SaveVars(const framework::Scope& scope, const std::vector& vars, const std::string& dirname, bool predicate) { diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h index ab492577c..20d635575 100644 --- a/paddle/fluid/inference/io.h +++ b/paddle/fluid/inference/io.h @@ -30,7 +30,7 @@ void Init(const std::vector argv); void LoadPersistables(framework::Executor* executor, framework::Scope* scope, const framework::ProgramDesc& main_program, const std::string& dirname, - const std::string& param_filename); + const std::string& param_filename, bool is_memory_load); std::unique_ptr Load(framework::Executor* executor, framework::Scope* scope, @@ -41,6 +41,12 @@ std::unique_ptr Load(framework::Executor* executor, const std::string& prog_filename, const std::string& param_filename); +std::unique_ptr Load(framework::Executor* executor, + framework::Scope* scope, + const std::string& prog_filename, + const std::string& param_filename, + bool is_memory_load); + // Save the variables from a scope to disk. void SaveVars(const framework::Scope& scope, const std::vector& vars, const std::string& dirname, diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc index 3a5f844de..aaa94c893 100644 --- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc @@ -93,9 +93,17 @@ void PrepareInputs(std::vector *input_slots, DataRecord *data, } } -void SetConfig(contrib::AnalysisConfig *cfg) { - cfg->prog_file = FLAGS_infer_model + "/__model__"; - cfg->param_file = FLAGS_infer_model + "/param"; +void SetConfig(contrib::AnalysisConfig *cfg, bool memory_load = false) { + if (memory_load) { + std::string buffer_prog, buffer_param; + ReadBinaryFile(FLAGS_infer_model + "/__model__", &buffer_prog); + ReadBinaryFile(FLAGS_infer_model + "/param", &buffer_param); + cfg->SetProgBufferAndParamBuffer(&buffer_prog[0], buffer_prog.size(), + &buffer_param[0], buffer_param.size()); + } else { + cfg->prog_file = FLAGS_infer_model + "/__model__"; + cfg->param_file = FLAGS_infer_model + "/param"; + } cfg->use_gpu = false; cfg->device = 0; cfg->specify_input_name = true; @@ -114,9 +122,9 @@ void SetInput(std::vector> *inputs) { } // Easy for profiling independently. -TEST(Analyzer_Chinese_ner, profile) { +void profile(bool memory_load = false) { contrib::AnalysisConfig cfg; - SetConfig(&cfg); + SetConfig(&cfg, memory_load); std::vector outputs; std::vector> input_slots_all; @@ -138,6 +146,12 @@ TEST(Analyzer_Chinese_ner, profile) { } } +TEST(Analyzer_Chinese_ner, profile) { profile(); } + +TEST(Analyzer_Chinese_ner, profile_memory_load) { + profile(true /* memory_load */); +} + // Check the fuse status TEST(Analyzer_Chinese_ner, fuse_statis) { contrib::AnalysisConfig cfg; diff --git a/paddle/fluid/inference/tests/api/config_printer.h b/paddle/fluid/inference/tests/api/config_printer.h index 4231eef72..c07a27674 100644 --- a/paddle/fluid/inference/tests/api/config_printer.h +++ b/paddle/fluid/inference/tests/api/config_printer.h @@ -49,8 +49,6 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) { os << GenSpaces(num_spaces) << "device: " << config.device << "\n"; os << GenSpaces(num_spaces) << "fraction_of_gpu_memory: " << config.fraction_of_gpu_memory << "\n"; - os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n"; - os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n"; os << GenSpaces(num_spaces) << "specify_input_name: " << config.specify_input_name << "\n"; os << GenSpaces(num_spaces) @@ -65,6 +63,13 @@ std::ostream &operator<<(std::ostream &os, os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n"; num_spaces++; os << *reinterpret_cast(&config); + if (!config.is_memory_load()) { + os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n"; + os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n"; + } else { + os << GenSpaces(num_spaces) + << "prog_file and param_file: load from memory \n"; + } os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.enable_ir_optim << "\n"; os << GenSpaces(num_spaces) diff --git a/paddle/fluid/operators/impl/CMakeLists.txt b/paddle/fluid/operators/impl/CMakeLists.txt new file mode 100644 index 000000000..1c0ff7c3d --- /dev/null +++ b/paddle/fluid/operators/impl/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(load_combine_impl SRCS load_combine.cc DEPS scope lod_tensor device_context op_registry data_type_transform) diff --git a/paddle/fluid/operators/impl/load_combine.cc b/paddle/fluid/operators/impl/load_combine.cc new file mode 100644 index 000000000..454d583a1 --- /dev/null +++ b/paddle/fluid/operators/impl/load_combine.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/impl/load_combine.h" + +namespace paddle { +namespace operators { +namespace impl { + +void LoadParamsFromStream(const std::vector &out_var_names, + const paddle::platform::Place &place, + bool load_as_fp16, std::istream *buffer, + const paddle::framework::Scope *scope) { + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); + for (size_t i = 0; i < out_var_names.size(); i++) { + auto *out_var = scope->FindVar(out_var_names[i]); + + PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", + out_var_names[i]); + + auto *tensor = out_var->GetMutable(); + + // Get data from fin to tensor + DeserializeFromStream(*buffer, tensor, *dev_ctx); + + auto in_dtype = framework::ToDataType(tensor->type()); + auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + + if (in_dtype != out_dtype) { + // convert to float16 tensor + auto in_kernel_type = framework::OpKernelType(in_dtype, place); + auto out_kernel_type = framework::OpKernelType(out_dtype, place); + framework::LoDTensor fp16_tensor; + // copy LoD info to the new tensor + fp16_tensor.set_lod(tensor->lod()); + framework::TransDataType(in_kernel_type, out_kernel_type, *tensor, + &fp16_tensor); + + // reset output tensor + out_var->Clear(); + tensor = out_var->GetMutable(); + tensor->set_lod(fp16_tensor.lod()); + tensor->ShareDataWith(fp16_tensor); + } + } +} + +} // namespace impl +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/impl/load_combine.h b/paddle/fluid/operators/impl/load_combine.h new file mode 100644 index 000000000..53ffcaf43 --- /dev/null +++ b/paddle/fluid/operators/impl/load_combine.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/framework/data_type_transform.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace impl { + +// Load parameters from a single stream. +void LoadParamsFromStream(const std::vector &out_var_names, + const platform::Place &place, bool load_as_fp16, + std::istream *buffer, const framework::Scope *scope); + +} // namespace impl +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 0522a9419..8cb7c6c9c 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -32,16 +32,26 @@ class LoadCombineOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto load_as_fp16 = Attr("load_as_fp16"); - - std::ifstream fin(filename); - PADDLE_ENFORCE(static_cast(fin), - "Cannot open file %s for load_combine op", filename); - + auto is_memory_load = Attr("is_memory_load"); auto out_var_names = Outputs("Out"); PADDLE_ENFORCE_GT( static_cast(out_var_names.size()), 0, "The number of output variables should be greater than 0."); - + if (!is_memory_load) { + std::ifstream fin(filename); + PADDLE_ENFORCE(static_cast(fin), + "Cannot open file %s for load_combine op", filename); + LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names); + } else { + PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory"); + std::stringstream fin(filename); + LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names); + } + } + void LoadParamsFromBuffer( + const framework::Scope &scope, const platform::Place &place, + std::istream *buffer, bool load_as_fp16, + const std::vector &out_var_names) const { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); @@ -54,11 +64,10 @@ class LoadCombineOp : public framework::OperatorBase { auto *tensor = out_var->GetMutable(); // Error checking - PADDLE_ENFORCE(static_cast(fin), "Cannot read more from file %s", - filename); + PADDLE_ENFORCE(static_cast(buffer), "Cannot read more"); // Get data from fin to tensor - DeserializeFromStream(fin, tensor, dev_ctx); + DeserializeFromStream(*buffer, tensor, dev_ctx); auto in_dtype = framework::ToDataType(tensor->type()); auto out_dtype = @@ -103,11 +112,17 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { "LoDTensors will be loaded from \"file_path\".") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); + AddAttr("is_memory_load", + "(boolean, default false)" + "If true, file_path is in memory, and LoDTensors will be " + "loaded directly from memory") + .SetDefault(false); AddComment(R"DOC( LoadCombine Operator. -LoadCombine operator loads LoDTensor variables from a file. The file should -contain one or more LoDTensors serialized using the SaveCombine operator. The +LoadCombine operator loads LoDTensor variables from a file, which could be +loaded in memory already. The file should contain one or more LoDTensors +serialized using the SaveCombine operator. The LoadCombine operator applies a deserialization strategy to appropriately load the LodTensors, and this strategy complements the serialization strategy used in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled -- GitLab