From 7a3bb061d8289f38be074da05395bc729eb58035 Mon Sep 17 00:00:00 2001
From: Zhaolong Xing
Date: Thu, 9 May 2019 15:39:01 +0800
Subject: [PATCH] fix: (#17279)

1. inference multi card occupy
2. facebox model inference occupy too much
test=develop
---
 .../inference/analysis/passes/CMakeLists.txt  |  2 +
 .../adjust_cudnn_workspace_size_pass.cc       | 43 +++++++++++++++++++
 .../passes/adjust_cudnn_workspace_size_pass.h | 41 ++++++++++++++++++
 .../fluid/inference/analysis/passes/passes.cc |  3 ++
 .../fluid/inference/analysis/passes/passes.h  |  2 +
 .../fluid/inference/api/analysis_predictor.cc |  8 +++-
 .../fluid/inference/api/paddle_pass_builder.h |  3 +-
 7 files changed, 100 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.cc
 create mode 100644 paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h

diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
index 9d74dc6c211..a8d0c69a54a 100644
--- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
@@ -3,11 +3,13 @@ cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument
 cc_library(memory_optim_pass SRCS memory_optimize_pass.cc DEPS analysis_pass zero_copy_tensor)
 cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager)
 cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass)
+cc_library(adjust_cudnn_workspace_size_pass SRCS adjust_cudnn_workspace_size_pass.cc DEPS analysis_pass graph_to_program_pass)
 
 cc_library(analysis_passes SRCS passes.cc DEPS
   ir_graph_build_pass
   ir_analysis_pass
   ir_params_sync_among_devices_pass
+  adjust_cudnn_workspace_size_pass
   memory_optim_pass
   ir_graph_to_program_pass
 )
diff --git a/paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.cc b/paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.cc
new file mode 100644
index 00000000000..0470e0d5a24
--- /dev/null
+++ b/paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void AdjustCudnnWorkSpacePass::RunImpl(Argument* argument) {
+  if (!argument->use_gpu()) return;
+  auto& graph = argument->main_graph();
+  auto nodes = graph.Nodes();
+  const int cudnn_workspace_size_MB = 64;
+  const std::string attr_name = "workspace_size_MB";
+
+  for (auto& node : nodes) {
+    if (!node->IsOp()) continue;
+    auto* op_desc = node->Op();
+    if (!op_desc->HasAttr(attr_name)) continue;
+    op_desc->SetAttr(attr_name, cudnn_workspace_size_MB);
+    op_desc->Flush();
+  }
+}
+
+std::string AdjustCudnnWorkSpacePass::repr() const {
+  return "adjust-cudnn-work-space-pass";
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h b/paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h
new file mode 100644
index 00000000000..65d1c545313
--- /dev/null
+++ b/paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/analysis/analysis_pass.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+/*
+ * The default cudnn workspace is 4G, we set it to 64M in this pass, which
+ * is applicable for most inference tasks.
+ */
+class AdjustCudnnWorkSpacePass : public AnalysisPass {
+ public:
+  void RunImpl(Argument *argument) override;
+  std::string repr() const override;
+};
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/passes/passes.cc b/paddle/fluid/inference/analysis/passes/passes.cc
index 161b127d6d5..a55904ed536 100644
--- a/paddle/fluid/inference/analysis/passes/passes.cc
+++ b/paddle/fluid/inference/analysis/passes/passes.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
#include "paddle/fluid/inference/analysis/passes/passes.h" +#include "paddle/fluid/inference/analysis/passes/adjust_cudnn_workspace_size_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h" @@ -35,6 +36,8 @@ PassRegistry::PassRegistry() { passes_.emplace( "ir_params_sync_among_devices_pass", std::unique_ptr(new IrParamsSyncAmongDevicesPass)); + passes_.emplace("adjust_cudnn_workspace_size_pass", + std::unique_ptr(new AdjustCudnnWorkSpacePass)); passes_.emplace( "ir_graph_to_program_pass", std::unique_ptr(new IrGraphToProgramPass)); diff --git a/paddle/fluid/inference/analysis/passes/passes.h b/paddle/fluid/inference/analysis/passes/passes.h index ea07e0dcbd9..8a13091d083 100644 --- a/paddle/fluid/inference/analysis/passes/passes.h +++ b/paddle/fluid/inference/analysis/passes/passes.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include "paddle/fluid/inference/analysis/analysis_pass.h" namespace paddle { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 321107377c2..e57d3a80456 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -120,7 +120,11 @@ bool AnalysisPredictor::PrepareScope( scope_ = parent_scope; status_is_cloned_ = true; } else { - paddle::framework::InitDevices(false); + if (config_.use_gpu_) { + paddle::framework::InitDevices(false, {config_.device_id_}); + } else { + paddle::framework::InitDevices(false, {}); + } scope_.reset(new paddle::framework::Scope()); status_is_cloned_ = false; } @@ -459,6 +463,8 @@ std::unique_ptr CreatePaddlePredictor< std::string flag = "--fraction_of_gpu_memory_to_use=" + std::to_string(fraction_of_gpu_memory); flags.push_back(flag); + flags.push_back("--selected_gpus=" + + std::to_string(config.gpu_device_id())); VLOG(3) << "set flag: " << flag; framework::InitGflags(flags); } diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 09ef195d5e6..057e7dc65d5 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -73,7 +73,8 @@ class PaddlePassBuilder { protected: std::vector analysis_passes_{ {"ir_graph_build_pass", "ir_analysis_pass", - "ir_params_sync_among_devices_pass"}}; + "ir_params_sync_among_devices_pass", + "adjust_cudnn_workspace_size_pass"}}; std::vector passes_; }; -- GitLab