From 06fee998e3d727d3e40fb3e70ed7083fe2053c66 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Thu, 17 Mar 2022 11:03:58 +0800 Subject: [PATCH] support gpu mixed precision inference (#40531) --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/mixed_precision_configure_pass.cc | 149 ++++++++++++++++++ .../ir/mixed_precision_configure_pass.h | 39 +++++ paddle/fluid/inference/analysis/argument.h | 3 + .../inference/analysis/ir_pass_manager.cc | 4 + .../ir_params_sync_among_devices_pass.cc | 66 ++++++-- .../ir_params_sync_among_devices_pass.h | 7 +- paddle/fluid/inference/api/analysis_config.cc | 33 ++++ .../fluid/inference/api/analysis_predictor.cc | 5 + .../api/analysis_predictor_tester.cc | 26 +++ .../inference/api/paddle_analysis_config.h | 16 ++ .../inference/api/paddle_pass_builder.cc | 34 ++++ .../fluid/inference/api/paddle_pass_builder.h | 12 ++ paddle/fluid/pybind/inference_api.cc | 3 + 14 files changed, 385 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/framework/ir/mixed_precision_configure_pass.cc create mode 100644 paddle/fluid/framework/ir/mixed_precision_configure_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 623c8a048c2..7aaaef712a6 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -97,6 +97,7 @@ pass_library(layer_norm_fuse_pass inference) pass_library(add_support_int8_pass inference) pass_library(matmul_scale_fuse_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference) +pass_library(mixed_precision_configure_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) diff --git a/paddle/fluid/framework/ir/mixed_precision_configure_pass.cc b/paddle/fluid/framework/ir/mixed_precision_configure_pass.cc new file mode 100644 index 00000000000..4aa59d9196b --- /dev/null +++ b/paddle/fluid/framework/ir/mixed_precision_configure_pass.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/mixed_precision_configure_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void MixedPrecisionConfigurePass::InsertCastOps( + Graph* graph, const StringSet& blacklist) const { + VLOG(3) << "Insert the cast op before and after the kernel that does not " + "supports fp16 precision"; + + auto update_cast_desc = [&]( + framework::OpDesc& desc, const std::string& x_name, + const std::string& out_name, const int in_dtype, const int out_dtype) { + desc.SetType("cast"); + desc.SetInput("X", {x_name}); + desc.SetOutput("Out", {out_name}); + desc.SetAttr("in_dtype", in_dtype); + desc.SetAttr("out_dtype", out_dtype); + desc.SetAttr("use_mkldnn", false); + desc.SetAttr("with_quant_attr", false); + desc.Flush(); + }; + + auto cast_input = [&](Graph* graph, Node* op_node, + const StringSet& cast_list) { + auto inlinks = op_node->inputs; + for (auto* pre_node : inlinks) { + if (pre_node->IsVar()) { + const auto is_persistable = pre_node->Var()->Persistable(); + const auto is_float = + pre_node->Var()->GetDataType() == proto::VarType::FP16 || + pre_node->Var()->GetDataType() == proto::VarType::FP32 || + pre_node->Var()->GetDataType() == proto::VarType::FP64; + if (!is_persistable && is_float) { + int suffix = 0; + for (auto* pre_node_input : pre_node->inputs) { + if (!pre_node_input->IsOp()) continue; + const auto& type = pre_node_input->Op()->Type(); + if (!cast_list.count(type) && type != "cast") { + std::string old_name = pre_node->Name(); + std::string new_name = + old_name + "_cast.tmp_" + std::to_string(suffix); + suffix++; + + framework::OpDesc new_op_desc(op_node->Op()->Block()); + // 4 for fp16, 5 for fp32 + update_cast_desc(new_op_desc, old_name, new_name, 4, 5); + auto* new_op = graph->CreateOpNode(&new_op_desc); + + VarDesc out_var(new_name); + out_var.SetPersistable(false); + auto* node_var = graph->CreateVarNode(&out_var); + + op_node->Op()->RenameInput(old_name, new_name); + IR_NODE_LINK_TO(pre_node, new_op); + IR_NODE_LINK_TO(new_op, node_var); + IR_NODE_LINK_TO(node_var, op_node); + } + } + } + } + } + }; + + auto cast_output = [&](Graph* graph, Node* op_node, + const StringSet& cast_list) { + auto outlinks = op_node->outputs; + for (auto* next_node : outlinks) { + if (next_node->IsVar()) { + const auto is_persistable = next_node->Var()->Persistable(); + const auto is_float = + next_node->Var()->GetDataType() == proto::VarType::FP16 || + next_node->Var()->GetDataType() == proto::VarType::FP32 || + next_node->Var()->GetDataType() == proto::VarType::FP64; + if (!is_persistable && is_float) { + int suffix = 0; + for (auto* next_node_output : next_node->outputs) { + if (!next_node_output->IsOp()) continue; + + const auto& type = next_node_output->Op()->Type(); + if (!cast_list.count(type) && type != "cast") { + std::string old_name = next_node->Name(); + std::string new_name = + old_name + "_cast.tmp_" + std::to_string(suffix); + suffix++; + + framework::OpDesc new_op_desc(op_node->Op()->Block()); + // 4 for fp16, 5 for fp32 + update_cast_desc(new_op_desc, old_name, new_name, 5, 4); + auto* new_op = graph->CreateOpNode(&new_op_desc); + + VarDesc out_var(new_name); + out_var.SetPersistable(false); + auto* node_var = graph->CreateVarNode(&out_var); + + next_node_output->Op()->RenameInput(old_name, new_name); + IR_NODE_LINK_TO(next_node, new_op); + IR_NODE_LINK_TO(new_op, node_var); + IR_NODE_LINK_TO(node_var, 
next_node_output);
+            }
+          }
+        }
+      }
+    }
+  };
+
+  for (auto* op_node :
+       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
+    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
+        op_node->Op()->Type() == "fetch")
+      continue;
+
+    const auto& type = op_node->Op()->Type();
+    if (blacklist.count(type)) {
+      cast_input(graph, op_node, blacklist);
+      cast_output(graph, op_node, blacklist);
+    }
+  }
+}
+
+void MixedPrecisionConfigurePass::ApplyImpl(Graph* graph) const {
+  const auto blacklist =
+      Get<std::unordered_set<std::string>>("gpu_fp16_disabled_op_types");
+  InsertCastOps(graph, blacklist);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(mixed_precision_configure_pass,
+              paddle::framework::ir::MixedPrecisionConfigurePass);
diff --git a/paddle/fluid/framework/ir/mixed_precision_configure_pass.h b/paddle/fluid/framework/ir/mixed_precision_configure_pass.h
new file mode 100644
index 00000000000..fc5a612ecb8
--- /dev/null
+++ b/paddle/fluid/framework/ir/mixed_precision_configure_pass.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using StringSet = std::unordered_set<std::string>;
+
+class MixedPrecisionConfigurePass : public FusePassBase {
+ public:
+  MixedPrecisionConfigurePass() = default;
+  virtual ~MixedPrecisionConfigurePass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const override;
+
+ private:
+  void InsertCastOps(Graph* graph, const StringSet& blacklist) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index a5c32164bf1..74e8ca3f229 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -188,6 +188,9 @@ struct Argument {
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
   DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
+  DECL_ARGUMENT_FIELD(use_gpu_fp16, UseGPUFp16, bool);
+  DECL_ARGUMENT_FIELD(gpu_fp16_disabled_op_types, GpuFp16DisabledOpTypes,
+                      std::unordered_set<std::string>);
 
   // Usually use for trt dynamic shape.
 // TRT will select the best kernel according to opt shape
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 796c86a3ad1..287c896e49b 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -189,6 +189,10 @@ void IRPassManager::CreatePasses(Argument *argument,
                 new int(argument->dlnne_min_subgraph_size()));
       pass->Set("program",
                 new framework::ProgramDesc *(&argument->main_program()));
+    } else if (pass_name == "mixed_precision_configure_pass") {
+      pass->Set("gpu_fp16_disabled_op_types",
+                new std::unordered_set<std::string>(
+                    argument->gpu_fp16_disabled_op_types()));
     }
     if (pass_name == "lite_subgraph_pass") {
       bool lite_enable_int8 =
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index daa18d8c78b..614eea24a0e 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
 #include "paddle/fluid/framework/data_layout.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -65,6 +66,26 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
 
 #else
 
+void IrParamsSyncAmongDevicesPass::GetVarNameToOpTypeMap(
+    const framework::ir::Graph &graph,
+    std::unordered_map<std::string, std::string> *var_name_op_type_map) {
+  std::vector<framework::ir::Node *> node_list =
+      framework::ir::TopologyVarientSort(
+          graph, static_cast<framework::ir::SortKind>(0));
+  for (auto *op_node : node_list) {
+    if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
+        op_node->Op()->Type() == "fetch")
+      continue;
+
+    for (auto *pre_node : op_node->inputs) {
+      if (pre_node->IsVar() && pre_node->Var()->Persistable()) {
+        var_name_op_type_map->insert(std::pair<std::string, std::string>(
+            pre_node->Var()->Name(), op_node->Op()->Type()));
+      }
+    }
+  }
+}
+
 void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
   // The parameters are on the cpu, therefore, synchronization is not necessary.
   if (!argument->use_gpu()) return;
@@ -102,6 +123,16 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
   if (with_dynamic_shape) {
     reserve_cpu_weights = true;
   }
+
+  bool mixed_precision_mode =
+      argument->Has("use_gpu_fp16") && argument->use_gpu_fp16();
+  std::unordered_map<std::string, std::string> var_name_op_type_map{};
+  std::unordered_set<std::string> blacklist{};
+  if (mixed_precision_mode) {
+    GetVarNameToOpTypeMap(graph, &var_name_op_type_map);
+    blacklist = argument->gpu_fp16_disabled_op_types();
+  }
+
   for (auto &var_name : all_vars) {
     if (std::count(repetitive_params.begin(), repetitive_params.end(),
                    var_name)) {
@@ -117,18 +148,29 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
         var->IsType<framework::LoDTensor>()) {
       auto *t = var->GetMutable<framework::LoDTensor>();
 
-      platform::CPUPlace cpu_place;
-      framework::LoDTensor temp_tensor;
-      temp_tensor.Resize(t->dims());
-      temp_tensor.mutable_data<float>(cpu_place);
-
-      // Copy the parameter data to a tmp tensor.
-      paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
-      // Reallocation the space on GPU
-      t->clear();
-
-      // Copy parameter data to newly allocated GPU space.
-      paddle::framework::TensorCopySync(temp_tensor, place, t);
+      bool is_float = t->dtype() == paddle::experimental::DataType::FLOAT32 ||
+                      t->dtype() == paddle::experimental::DataType::FLOAT64;
+      if (mixed_precision_mode &&
+          !blacklist.count(var_name_op_type_map[var_name]) && is_float) {
+        framework::Tensor half_tensor;
+        half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
+        half_tensor.Resize(t->dims());
+        auto *half_data =
+            half_tensor.mutable_data<float16>(platform::CPUPlace());
+        for (int i = 0; i < t->numel(); i++) {
+          auto *data = t->mutable_data<float>(platform::CPUPlace());
+          half_data[i] = static_cast<float16>(data[i]);
+        }
+        t->clear();
+        paddle::framework::TensorCopySync(half_tensor, place, t);
+      } else {
+        platform::CPUPlace cpu_place;
+        framework::LoDTensor temp_tensor;
+        temp_tensor.Resize(t->dims());
+        paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
+        t->clear();
+        paddle::framework::TensorCopySync(temp_tensor, place, t);
+      }
     }
   }
 }
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
index d5e98ec886e..f8209f051d5 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
@@ -38,7 +38,12 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
 #ifdef PADDLE_WITH_ASCEND_CL
   void CopyParamsToNpu(Argument *argument);
 #else
-  void CopyParamsToGpu(Argument *argument);
+
+  void GetVarNameToOpTypeMap(
+      const framework::ir::Graph& graph,
+      std::unordered_map<std::string, std::string>* var_name_op_type_map);
+
+  void CopyParamsToGpu(Argument* argument);
 #endif
 };
 
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 41c01d3b7e2..d08d28a3f62 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -83,6 +83,7 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
 
   Update();
 }
+
 void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
                                   int device_id) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -97,12 +98,26 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
 
   Update();
 }
+
 void AnalysisConfig::DisableGpu() {
   use_gpu_ = false;
 
   Update();
 }
 
+void AnalysisConfig::Exp_EnableUseGpuFp16(
+    std::unordered_set<std::string> op_list) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  use_gpu_fp16_ = true;
+  gpu_fp16_disabled_op_types_.insert(op_list.begin(), op_list.end());
+#else
+  LOG(ERROR) << "Please compile with gpu to Exp_EnableUseGpuFp16()";
+  use_gpu_fp16_ = false;
+#endif
+
+  Update();
+}
+
 void AnalysisConfig::DisableFCPadding() {
   use_fc_padding_ = false;
 
@@ -213,6 +228,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(use_cudnn_);
   CP_MEMBER(gpu_device_id_);
   CP_MEMBER(memory_pool_init_size_mb_);
+  CP_MEMBER(use_gpu_fp16_);
+  CP_MEMBER(gpu_fp16_disabled_op_types_);
 
   CP_MEMBER(enable_memory_optim_);
   // TensorRT related.
@@ -573,6 +590,20 @@ void AnalysisConfig::Update() {
 #endif
   }
 
+  if (use_gpu_fp16_) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (!enable_ir_optim_) {
+      LOG(ERROR) << "Exp_EnableUseGpuFp16() only works when IR optimization is "
+                    "enabled.";
+    } else if (!use_gpu()) {
+      LOG(ERROR)
+          << "Exp_EnableUseGpuFp16() only works when use_gpu is enabled.";
+    } else {
+      pass_builder()->Exp_EnableUseGpuFp16();
+    }
+#endif
+  }
+
   if (use_mkldnn_) {
 #ifdef PADDLE_WITH_MKLDNN
     if (!enable_ir_optim_) {
@@ -669,6 +700,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
 
   ss << use_gpu_;
+  ss << use_gpu_fp16_;
+  for (auto &item : gpu_fp16_disabled_op_types_) ss << item;
   ss << use_fc_padding_;
   ss << gpu_device_id_;
   ss << xpu_device_id_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 871ed596a3e..6f765ef415e 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -872,6 +872,11 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
   }
 
+  if (config_.gpu_fp16_enabled()) {
+    argument_.SetUseGPUFp16(true);
+    argument_.SetGpuFp16DisabledOpTypes(config_.gpu_fp16_disabled_op_types_);
+  }
+
   if (config_.lite_engine_enabled()) {
     argument_.SetCpuMathLibraryNumThreads(
         config_.cpu_math_library_num_threads());
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index 2c6e8f4f1a4..ecb5eaf9825 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -375,6 +375,19 @@ TEST(AnalysisPredictor, enable_onnxruntime) {
   ASSERT_TRUE(!config.use_onnxruntime());
 }
 
+TEST(AnalysisPredictor, exp_enable_use_gpu_fp16) {
+  AnalysisConfig config;
+  config.SwitchIrOptim();
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  config.EnableUseGpu(100, 0);
+  config.Exp_EnableUseGpuFp16();
+  ASSERT_TRUE(config.gpu_fp16_enabled());
+#else
+  config.DisableGpu();
+#endif
+  LOG(INFO) << config.Summary();
+}
+
 }  // namespace paddle
 
 namespace paddle_infer {
@@ -434,6 +447,19 @@ TEST(Predictor, EnableONNXRuntime) {
   auto predictor = CreatePredictor(config);
 }
 
+TEST(Predictor, Exp_EnableUseGpuFp16) {
+  Config config;
+  config.SetModel(FLAGS_dirname);
+  config.SwitchIrOptim();
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  config.EnableUseGpu(100, 0);
+  config.Exp_EnableUseGpuFp16();
+#else
+  config.DisableGpu();
+#endif
+  auto predictor = CreatePredictor(config);
+}
+
 TEST(Tensor, CpuShareExternalData) {
   Config config;
   config.SetModel(FLAGS_dirname);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 7b765e3fa8a..bdfe0e46e9c 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -253,6 +253,19 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   ///
   void DisableGpu();
+  ///
+  /// \brief Enable GPU fp16 precision computation, in experimental state.
+  ///
+  /// \param op_list The operator type list.
+  ///
+  void Exp_EnableUseGpuFp16(std::unordered_set<std::string> op_list = {});
+  ///
+  /// \brief A boolean state telling whether the GPU fp16 precision is turned
+  /// on.
+  ///
+  /// \return bool Whether the GPU fp16 precision is turned on.
+  ///
+  bool gpu_fp16_enabled() const { return use_gpu_fp16_; }
 
   ///
   /// \brief Turn on XPU.
@@ -859,6 +872,9 @@ struct PD_INFER_DECL AnalysisConfig {
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
   bool thread_local_stream_{false};
+  bool use_gpu_fp16_{false};
+  std::unordered_set<std::string> gpu_fp16_disabled_op_types_{
+      "conv2d_fusion", "conv2d", "roll", "strided_slice"};
 
   bool use_cudnn_{false};
 
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 22d9dedb32e..95975d8f2a8 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -172,6 +172,40 @@ void GpuPassStrategy::EnableCUDNN() {
   use_cudnn_ = true;
 }
 
+void GpuPassStrategy::Exp_EnableUseGpuFp16() {
+  passes_.assign({
+    "is_test_pass",                                  //
+    "simplify_with_basic_ops_pass",                  //
+    "conv_bn_fuse_pass",                             //
+    "conv_eltwiseadd_bn_fuse_pass",                  //
+    "embedding_eltwise_layernorm_fuse_pass",         //
+    "multihead_matmul_fuse_pass_v2",                 //
+    "gpu_cpu_squeeze2_matmul_fuse_pass",             //
+    "gpu_cpu_reshape2_matmul_fuse_pass",             //
+    "gpu_cpu_flatten2_matmul_fuse_pass",             //
+    "gpu_cpu_map_matmul_v2_to_mul_pass",             //
+    "gpu_cpu_map_matmul_v2_to_matmul_pass",          //
+    "gpu_cpu_map_matmul_to_mul_pass",                //
+    // "fc_fuse_pass",                               //
+    "fc_elementwise_layernorm_fuse_pass",            //
+#if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
+                           // guaranteed at least v7
+// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
+// disable the pass.
+#if !(CUDNN_VERSION >= 8000 && CUDNN_VERSION < 8100)
+    "conv_elementwise_add_act_fuse_pass",            //
+    "conv_elementwise_add2_act_fuse_pass",           //
+#endif
+    "conv_elementwise_add_fuse_pass",                //
+#endif                                               //
+    "transpose_flatten_concat_fuse_pass",            //
+    "mixed_precision_configure_pass",                //
+    "runtime_context_cache_pass"                     //
+  });
+
+  use_gpu_fp16_ = true;
+}
+
 void GpuPassStrategy::EnableMKLDNN() {
   LOG(ERROR) << "GPU not support MKLDNN yet";
 }
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 351cf71e5ca..02290ed33ff 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -125,6 +125,9 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
   /// \brief Enable the use of cuDNN kernel.
   virtual void EnableCUDNN() {}
 
+  /// \brief Enable use gpu fp16 kernel.
+  virtual void Exp_EnableUseGpuFp16() {}
+
   /// \brief Enable the use of MKLDNN.
   /// The MKLDNN control exists in both CPU and GPU mode, because there can
   /// still be some CPU kernels running in GPU mode.
@@ -140,6 +143,10 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
   /// \return A bool variable implying whether we are in gpu mode.
   bool use_gpu() const { return use_gpu_; }
 
+  /// \brief Check if we are using gpu fp16 kernel.
+  /// \return A bool variable implying whether we are in gpu fp16 mode.
+  bool use_gpu_fp16() const { return use_gpu_fp16_; }
+
   /// \brief Check if we are using xpu.
   /// \return A bool variable implying whether we are in xpu mode.
   bool use_xpu() const { return use_xpu_; }
@@ -162,6 +169,7 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
   bool use_npu_{false};
   bool use_ipu_{false};
   bool use_mkldnn_{false};
+  bool use_gpu_fp16_{false};
   /// \endcond
 };
 
@@ -223,6 +231,9 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
   /// \brief Enable the use of cuDNN kernel.
   void EnableCUDNN() override;
 
+  /// \brief Enable the use of gpu fp16 kernel.
+  void Exp_EnableUseGpuFp16() override;
+
   /// \brief Not supported in GPU mode yet.
   void EnableMKLDNN() override;
 
@@ -238,6 +249,7 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
  protected:
   /// \cond Protected
   bool use_cudnn_{false};
+  bool use_gpu_fp16_{false};
   /// \endcond
 };
 
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index b008308e27d..c8f0acd0b8a 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -551,6 +551,9 @@ void BindAnalysisConfig(py::module *m) {
       .def("params_file", &AnalysisConfig::params_file)
       .def("enable_use_gpu", &AnalysisConfig::EnableUseGpu,
           py::arg("memory_pool_init_size_mb"), py::arg("device_id") = 0)
+      .def("exp_enable_use_gpu_fp16", &AnalysisConfig::Exp_EnableUseGpuFp16,
+           py::arg("gpu_fp16_disabled_op_types") =
+               std::unordered_set<std::string>({}))
       .def("enable_xpu", &AnalysisConfig::EnableXpu,
           py::arg("l3_workspace_size") = 16 * 1024 * 1024,
           py::arg("locked") = false, py::arg("autotune") = true,
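Usage note (not part of the upstream patch): the sketch below shows one way to drive the experimental API added by this commit from the C++ inference library. The model directory "./resnet50", the 1x3x224x224 input shape, and the extra "softmax" entry in the disable list are placeholder assumptions; only EnableUseGpu, SwitchIrOptim, Exp_EnableUseGpuFp16, and the standard predictor calls come from the patch or the existing Paddle Inference API.

// Minimal sketch: enable the experimental GPU fp16 path for inference.
// Assumes a saved model under "./resnet50" and a float input of shape 1x3x224x224.
#include <functional>
#include <numeric>
#include <vector>

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./resnet50");   // placeholder model directory
  config.EnableUseGpu(100, 0);     // 100 MB initial pool on GPU device 0
  config.SwitchIrOptim(true);      // required: the fp16 mode is wired through IR passes
  // Ops listed here stay in fp32, in addition to the built-in defaults
  // (conv2d_fusion, conv2d, roll, strided_slice).
  config.Exp_EnableUseGpuFp16({"softmax"});

  auto predictor = paddle_infer::CreatePredictor(config);

  // Feed a constant input tensor.
  auto input_names = predictor->GetInputNames();
  auto input = predictor->GetInputHandle(input_names[0]);
  std::vector<int> shape = {1, 3, 224, 224};  // placeholder input shape
  std::vector<float> data(1 * 3 * 224 * 224, 1.0f);
  input->Reshape(shape);
  input->CopyFromCpu(data.data());

  predictor->Run();

  // Fetch the first output back to the host.
  auto output_names = predictor->GetOutputNames();
  auto output = predictor->GetOutputHandle(output_names[0]);
  auto out_shape = output->shape();
  int out_num = std::accumulate(out_shape.begin(), out_shape.end(), 1,
                                std::multiplies<int>());
  std::vector<float> out_data(out_num);
  output->CopyToCpu(out_data.data());
  return 0;
}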