// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" #include #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace inference { extern void ReadBinaryFile(const std::string &filename, std::string *contents); namespace analysis { void IrGraphBuildPass::RunImpl(Argument *argument) { if (!argument->scope_valid()) { argument->SetScope(new framework::Scope); } PADDLE_ENFORCE(argument->use_gpu_valid()); // The load program should run on the same device with the inference program, // so that the parameters will on the same device, or they will keep copying // between difference devices. platform::Place place; if (argument->use_gpu()) { PADDLE_ENFORCE(argument->gpu_device_id_valid()); place = platform::CUDAPlace(argument->gpu_device_id()); } else { place = platform::CPUPlace(); } if (argument->model_dir_valid()) { auto program = LoadModel(argument->model_dir(), argument->scope_ptr(), place); argument->SetMainProgram(program.release()); } else if (argument->model_program_path_valid() && argument->model_params_path_valid()) { auto program = LoadModel(argument->model_program_path(), argument->model_params_path(), argument->scope_ptr(), place); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( "either model_dir or (program path and parameter path) should be set."); } auto graph = std::unique_ptr(new Graph(argument->main_program())); argument->SetMainGraph(graph.release()); argument->main_graph().Set(framework::ir::kParamScopeAttr, new framework::Scope *(argument->scope_ptr())); } std::unique_ptr IrGraphBuildPass::LoadModel( const std::string &path, framework::Scope *scope, const platform::Place &place) { framework::Executor exe(place); return Load(&exe, scope, path); } std::unique_ptr IrGraphBuildPass::LoadModel( const std::string &program_path, const std::string ¶ms_path, framework::Scope *scope, const platform::Place &place) { framework::Executor exe(place); return Load(&exe, scope, program_path, params_path); } std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; } } // namespace analysis } // namespace inference } // namespace paddle