From 4bf6817cbc7b98dd695bd60ac3e7ae6a460ed72f Mon Sep 17 00:00:00 2001 From: superjomn Date: Fri, 16 Nov 2018 20:49:38 +0800 Subject: [PATCH] fix gpu load model the parameters will load from CPUPlace, that will keep copying data between CPU and GPU places. test=develop --- paddle/fluid/inference/analysis/argument.h | 1 + .../analysis/passes/ir_graph_build_pass.cc | 24 ++++++++++++++----- .../analysis/passes/ir_graph_build_pass.h | 8 ++++--- .../fluid/inference/api/analysis_predictor.cc | 4 ++-- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index d7a2f3d1e..21203e2d9 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -116,6 +116,7 @@ struct Argument { std::vector); DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool); + DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int); DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool); DECL_ARGUMENT_FIELD(tensorrt_node_teller, TensorRtNodeTeller, std::function); diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index a30fef08b..d5e0d90de 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -30,15 +30,28 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { if (!argument->scope_valid()) { argument->SetScope(new framework::Scope); } + PADDLE_ENFORCE(argument->use_gpu_valid()); + + // The load program should run on the same device with the inference program, + // so that the parameters will on the same device, or they will keep copying + // between difference devices. + platform::Place place; + if (argument->use_gpu()) { + PADDLE_ENFORCE(argument->gpu_device_id_valid()); + place = platform::CUDAPlace(argument->gpu_device_id()); + } else { + place = platform::CPUPlace(); + } if (argument->model_dir_valid()) { - auto program = LoadModel(argument->model_dir(), argument->scope_ptr()); + auto program = + LoadModel(argument->model_dir(), argument->scope_ptr(), place); argument->SetMainProgram(program.release()); } else if (argument->model_program_path_valid() && argument->model_params_path_valid()) { auto program = LoadModel(argument->model_program_path(), argument->model_params_path(), - argument->scope_ptr()); + argument->scope_ptr(), place); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( @@ -52,16 +65,15 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { } std::unique_ptr IrGraphBuildPass::LoadModel( - const std::string &path, framework::Scope *scope) { - platform::CPUPlace place; + const std::string &path, framework::Scope *scope, + const platform::Place &place) { framework::Executor exe(place); return Load(&exe, scope, path); } std::unique_ptr IrGraphBuildPass::LoadModel( const std::string &program_path, const std::string ¶ms_path, - framework::Scope *scope) { - platform::CPUPlace place; + framework::Scope *scope, const platform::Place &place) { framework::Executor exe(place); return Load(&exe, scope, program_path, params_path); } diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h index 3291e4f6a..b0a0b8b75 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h @@ -32,11 +32,13 @@ class IrGraphBuildPass : public AnalysisPass { std::string repr() const override; private: - std::unique_ptr LoadModel(const std::string &path, - framework::Scope *scope); + std::unique_ptr LoadModel( + const std::string &path, framework::Scope *scope, + const boost::variant &place); std::unique_ptr LoadModel( const std::string &program_path, const std::string ¶ms_path, - framework::Scope *scope); + framework::Scope *scope, + const boost::variant &place); std::string model_binary_str_; }; diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index d19505877..3a707907d 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -285,6 +285,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { status_program_optimized_ = true; argument_.SetUseGPU(config_.use_gpu); + argument_.SetGPUDeviceId(config_.device); // Analyze inference_program if (!config_.model_dir.empty()) { argument_.SetModelDir(config_.model_dir); @@ -491,8 +492,7 @@ bool AnalysisPredictor::LoadParameters() { } // Use NaiveExecutor to Load parameters. - platform::CPUPlace place; - framework::NaiveExecutor e(place); + framework::NaiveExecutor e(place_); e.Prepare(scope_.get(), *load_program, 0, false); e.Run(); VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load"; -- GitLab