From ecff21e72bb7f66f2390cd72fa2e42423a5b6f18 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Fri, 25 Aug 2023 14:02:47 +0800
Subject: [PATCH] [Inference] auto mixed precision inference support white
 list (#56535)

* auto mixed precision inference support white list

* update

* update

* update

* move down identity_op_clean_pass

* fix code style
---
 .../framework/ir/auto_mixed_precision_pass.cc      | 27 ++--
 .../framework/ir/auto_mixed_precision_pass.h       |  4 +-
 .../framework/ir/identity_op_clean_pass.cc         | 89 ++++++++++++++++++-
 .../framework/ir/identity_op_clean_pass.h          |  4 +
 paddle/fluid/inference/analysis/argument.h         |  3 +
 .../inference/analysis/ir_pass_manager.cc          |  3 +
 .../ir_passes/tensorrt_subgraph_pass.cc            | 16 +++-
 .../inference/analysis/passes/CMakeLists.txt       |  2 +-
 .../passes/convert_to_mixed_precision.cc           | 44 +++++----
 .../passes/convert_to_mixed_precision.h            | 10 ++-
 paddle/fluid/inference/api/analysis_config.cc      |  7 ++
 .../fluid/inference/api/analysis_predictor.cc      |  7 +-
 .../inference/api/paddle_analysis_config.h         |  9 ++
 .../inference/api/paddle_inference_api.h           |  3 +-
 .../inference/api/paddle_pass_builder.cc           |  4 +-
 paddle/fluid/pybind/inference_api.cc               |  5 +-
 python/paddle/inference/wrapper.py                 |  7 +-
 test/ir/inference/test_identity_clean_pass.py      | 55 ++++++++++++
 18 files changed, 257 insertions(+), 42 deletions(-)

diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
index 1161a53e16e..1d6092331d0 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -165,7 +165,9 @@ void DoInsertCastOp(Graph* graph,
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& black_list) {
+                        const std::unordered_set<std::string>& black_list,
+                        const std::unordered_set<std::string>& white_list) {
+  if (white_list.count(op_type)) return true;
   return black_list.count(op_type) == 0 &&
          KernelSupportPrecision(op_type, backend, precision);
 }
@@ -230,11 +232,16 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
   if (skip_pass_) return;
 
   black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
+  white_list_ = Get<std::unordered_set<std::string>>("mixed_white_list");
   SetDefaultBlacklist();
   VLOG(4) << "black_list has ";
   for (const auto& name : black_list_) {
     VLOG(4) << " - " << name;
   }
+  VLOG(4) << "white_list has ";
+  for (const auto& name : white_list_) {
+    VLOG(4) << " - " << name;
+  }
 
   if (Has("enable_low_precision_io")) {
     enable_low_precision_io_ = Get<bool>("enable_low_precision_io");
@@ -403,8 +410,11 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
           op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io");
       support_low_precision = enable_fp16 && !enable_int8 && low_precision_io;
     } else {
-      support_low_precision = OpSupportPrecision(
-          GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
+      support_low_precision = OpSupportPrecision(GetOpOriginalType(op_type),
+                                                 backend_,
+                                                 low_precision_,
+                                                 black_list_,
+                                                 white_list_);
 
       std::unordered_set<std::string> check_dtype_op_blacklist(
           {"arg_max", "arg_min"});
@@ -422,8 +432,8 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
                                 out_dtype == -1);
       }
 
-      // If scale op's "scale" and "bias" attr value exceed the range of fp16
-      // and bf16, it cannot run at low precision.
+      // If scale op's "scale" and "bias" attr value exceed the range of
+      // fp16 and bf16, it cannot run at low precision.
       if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
         auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
         auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
@@ -500,9 +510,9 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
                   << " is output of " << op_type;
         }
 
-        // the select_input op's input var should not convert to low precision.
-        // when op's output var is select_input op's input var, the op should
-        // not run at low precision.
+        // the select_input op's input var should not convert to low
+        // precision. when op's output var is select_input op's input var, the
+        // op should not run at low precision.
         if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
           for (auto* in_var_node : op_node->inputs) {
             CHECK_EQ(in_var_node->IsVar(), true);
@@ -517,6 +527,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
       // output var, then op_2 should not run at low precision.
       if (GetOpOriginalType(op_type) != "feed" &&
           GetOpOriginalType(op_type) != "tensorrt_engine" &&
+          white_list_.count(GetOpOriginalType(op_type)) == 0 &&
           !KernelSupportPrecision(
               GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) {
         for (auto* out_var_node : op_node->outputs) {
diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.h b/paddle/fluid/framework/ir/auto_mixed_precision_pass.h
index 99ad62874c9..3a5c5f0f54f 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.h
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.h
@@ -75,6 +75,7 @@ class AutoMixedPrecisionPass : public FusePassBase {
   mutable phi::Backend backend_{phi::Backend::UNDEFINED};
 
   mutable std::unordered_set<std::string> black_list_;
+  mutable std::unordered_set<std::string> white_list_;
 
   // subgraph id -> pointer to subgraph
   mutable std::vector<Graph*> subgraphes_;
@@ -93,7 +94,8 @@ class AutoMixedPrecisionPass : public FusePassBase {
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& black_list);
+                        const std::unordered_set<std::string>& black_list,
+                        const std::unordered_set<std::string>& white_list);
 
 void DoInsertCastOp(Graph* graph,
                     Node* var_node,
diff --git a/paddle/fluid/framework/ir/identity_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_op_clean_pass.cc
index 0d6c444878d..8819d4a2766 100644
--- a/paddle/fluid/framework/ir/identity_op_clean_pass.cc
+++ b/paddle/fluid/framework/ir/identity_op_clean_pass.cc
@@ -89,11 +89,52 @@ FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern,
   useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
 }
 
-}  // namespace patterns
+// pre_op -> pre_op_out -> cast_op_1 -> cast_op_1_out -> cast_op_2 ->
+// cast_op_2_out
+//                      ->
+// pre_op -> cast_op_2_out
+struct FindTwoCastOpPattern : public PatternBase {
+  FindTwoCastOpPattern(PDPattern* pattern, const std::string& name_scope);
 
-void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
-  Init(name_scope_, graph);
+  // declare operator node's name
+  PATTERN_DECL_NODE(pre_op_out);
+  PATTERN_DECL_NODE(cast_op_1);
+  PATTERN_DECL_NODE(cast_op_1_out);
+  PATTERN_DECL_NODE(cast_op_2);
+  PATTERN_DECL_NODE(cast_op_2_out);
+};
+
+FindTwoCastOpPattern::FindTwoCastOpPattern(PDPattern* pattern,
+                                           const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* pre_op_out = pattern->NewNode(pre_op_out_repr())
+                         ->assert_is_var()
+                         ->assert_var_not_persistable()
+                         ->assert_has_n_outputs(1)
+                         ->assert_more([](Node* x) {
+                           for (auto* op : x->inputs) {
+                             CHECK_EQ(op->IsOp(), true);
+                             const auto& op_type = op->Op()->Type();
+                             if (op_type == "conditional_block" ||
"conditional_block" || + op_type == "while" || op_type == "feed") { + return false; + } + } + return true; + }); + + auto* cast_op_1 = pattern->NewNode(cast_op_1_repr())->assert_is_op("cast"); + auto* cast_op_1_out = pattern->NewNode(cast_op_1_out_repr())->assert_is_var(); + auto* cast_op_2 = pattern->NewNode(cast_op_2_repr())->assert_is_op("cast"); + auto* cast_op_2_out = pattern->NewNode(cast_op_2_out_repr())->assert_is_var(); + + cast_op_1->LinksFrom({pre_op_out}).LinksTo({cast_op_1_out}); + cast_op_2->LinksFrom({cast_op_1_out}).LinksTo({cast_op_2_out}); +} +} // namespace patterns + +int IdentityOpCleanPass::CleanUselessOp(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::FindUselessOpPattern pattern(gpd.mutable_pattern(), name_scope_); @@ -119,6 +160,48 @@ void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const { }; gpd(graph, handler); + return found_count; +} + +int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::FindTwoCastOpPattern pattern(gpd.mutable_pattern(), name_scope_); + + int found_count = 0; + GraphPatternDetector::handle_t handler = + [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + GET_IR_NODE_FROM_SUBGRAPH(pre_op_out, pre_op_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast_op_1, cast_op_1, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast_op_1_out, cast_op_1_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast_op_2, cast_op_2, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast_op_2_out, cast_op_2_out, pattern); + CHECK_EQ(pre_op_out->IsVar(), true); + CHECK_EQ(cast_op_1_out->IsVar(), true); + CHECK_EQ(cast_op_2_out->IsVar(), true); + CHECK_EQ(cast_op_1->IsOp(), true); + CHECK_EQ(cast_op_2->IsOp(), true); + if (pre_op_out->Var()->GetDataType() == + cast_op_2_out->Var()->GetDataType()) { + for (auto* prev_op : pre_op_out->inputs) { + CHECK_EQ(prev_op->IsOp(), true); + prev_op->Op()->RenameOutput(pre_op_out->Var()->Name(), + cast_op_2_out->Var()->Name()); + IR_NODE_LINK_TO(prev_op, cast_op_2_out); + } + + GraphSafeRemoveNodes( + graph, {pre_op_out, cast_op_1, cast_op_1_out, cast_op_2}); + found_count++; + } + }; + + gpd(graph, handler); + return found_count; +} + +void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const { + Init(name_scope_, graph); + int found_count = CleanUselessOp(graph) + CleanTwoCastOp(graph); AddStatis(found_count); } diff --git a/paddle/fluid/framework/ir/identity_op_clean_pass.h b/paddle/fluid/framework/ir/identity_op_clean_pass.h index 50a95dfac9e..ea6be220850 100644 --- a/paddle/fluid/framework/ir/identity_op_clean_pass.h +++ b/paddle/fluid/framework/ir/identity_op_clean_pass.h @@ -27,6 +27,10 @@ class IdentityOpCleanPass : public FusePassBase { void ApplyImpl(ir::Graph* graph) const override; private: + int CleanUselessOp(ir::Graph* graph) const; + + int CleanTwoCastOp(ir::Graph* graph) const; + const std::string name_scope_{"identity_op_clean_pass"}; }; diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index afe81f24842..b3757886e2f 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -419,6 +419,9 @@ struct Argument { DECL_ARGUMENT_FIELD(mixed_black_list, MixedBlackList, std::unordered_set); + DECL_ARGUMENT_FIELD(mixed_white_list, + MixedWhiteList, + std::unordered_set); DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool); DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int); DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, 
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index fca0e1eeabc..703ae0df544 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -100,6 +100,9 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set(
           "mixed_black_list",
           new std::unordered_set<std::string>(argument->mixed_black_list()));
+      pass->Set(
+          "mixed_white_list",
+          new std::unordered_set<std::string>(argument->mixed_white_list()));
       pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
       pass->Set("enable_custom_device_mixed",
                 new bool(argument->enable_custom_device_mixed()));
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 6b65ccc8b71..e65aff11180 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -50,7 +50,8 @@ void OutputProcess(framework::ir::Graph *graph,
                    const std::unordered_set<framework::ir::Node *> &trt_outputs,
                    phi::Backend backend,
                    phi::DataType precision,
-                   const std::unordered_set<std::string> &blacklist) {
+                   const std::unordered_set<std::string> &blacklist,
+                   const std::unordered_set<std::string> &whitelist) {
   framework::BlockDesc *block_desc{nullptr};
   int suffix = 0;
   std::unordered_map<framework::ir::Node *, framework::ir::Node *>
@@ -86,7 +87,8 @@ void OutputProcess(framework::ir::Graph *graph,
               phi::TransToPhiKernelName(next_op->Op()->Type()),
               backend,
               precision,
-              blacklist)) {
+              blacklist,
+              whitelist)) {
         InsertCastOp(graph,
                      var_node,
                      next_op,
@@ -363,6 +365,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
       static_cast<phi::DataType>(Get<int>("model_precision"));
   auto mixed_black_list =
       Get<std::unordered_set<std::string>>("mixed_black_list");
+  auto mixed_white_list =
+      Get<std::unordered_set<std::string>>("mixed_white_list");
 
   std::set<std::string> output_names;
   std::set<std::string> output_names_with_id;
@@ -414,8 +418,12 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
         static_cast<int>(x->Var()->GetDataType());
   }
 
-  OutputProcess(
-      graph, trt_outputs, phi::Backend::GPU, model_precision, mixed_black_list);
+  OutputProcess(graph,
+                trt_outputs,
+                phi::Backend::GPU,
+                model_precision,
+                mixed_black_list,
+                mixed_white_list);
 
   std::unordered_map<std::string, std::string> output_name_map;
   std::unordered_map<std::string, framework::ir::Node *> graph_var_map;
diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
index bc41a34db5e..0af6876faca 100644
--- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
@@ -14,7 +14,7 @@ cc_library(
   convert_to_mixed_precision
   SRCS convert_to_mixed_precision.cc
   DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass
-       constant_folding_pass)
+       constant_folding_pass identity_op_clean_pass)
 cc_library(
   ir_params_sync_among_devices_pass
   SRCS ir_params_sync_among_devices_pass.cc
diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index 2c942f8aa46..24d18e0c3f9 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
 #include "paddle/fluid/framework/ir/constant_folding_pass.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/identity_op_clean_pass.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/phi/common/backend.h"
 
@@ -33,7 +34,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
     phi::DataType mixed_precision,
     phi::Backend backend,
     bool keep_io_types,
-    const std::unordered_set<std::string>& black_list)
+    const std::unordered_set<std::string>& black_list,
+    const std::unordered_set<std::string>& white_list)
     : model_file_(model_file),
       params_file_(params_file),
       mixed_model_file_(mixed_model_file),
@@ -41,7 +43,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
       mixed_precision_(mixed_precision),
       backend_(backend),
       keep_io_types_(keep_io_types),
-      black_list_(black_list) {
+      black_list_(black_list),
+      white_list_(white_list) {
   switch (backend_) {
     case phi::Backend::GPU:
       PADDLE_ENFORCE(mixed_precision_ == phi::DataType::FLOAT16 ||
@@ -88,19 +91,27 @@ void ConvertToMixedPrecisionPass::Run() {
   framework::ir::ConstantFoldingPass constant_folding_pass;
   constant_folding_pass.Apply(main_graph_.get());
-  framework::ir::AutoMixedPrecisionPass pass;
-  pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
+
+  framework::ir::AutoMixedPrecisionPass auto_mixed_precision_pass;
+  auto_mixed_precision_pass.Set("mixed_precision_mode",
+                                new int{static_cast<int>(mixed_precision_)});
   if (backend_ == phi::Backend::GPU) {
-    pass.Set("enable_gpu_mixed", new bool{true});
+    auto_mixed_precision_pass.Set("enable_gpu_mixed", new bool{true});
   } else if (backend_ == phi::Backend::XPU) {
-    pass.Set("enable_xpu_mixed", new bool{true});
+    auto_mixed_precision_pass.Set("enable_xpu_mixed", new bool{true});
   } else if (backend_ == phi::Backend::CUSTOM) {
-    pass.Set("enable_custom_device_mixed", new bool{true});
+    auto_mixed_precision_pass.Set("enable_custom_device_mixed", new bool{true});
   }
-  pass.Set("mixed_black_list",
-           new std::unordered_set<std::string>{black_list_});
-  pass.Set("enable_low_precision_io", new bool{!keep_io_types_});
-  pass.Apply(main_graph_.get());
+  auto_mixed_precision_pass.Set(
+      "mixed_black_list", new std::unordered_set<std::string>{black_list_});
+  auto_mixed_precision_pass.Set(
+      "mixed_white_list", new std::unordered_set<std::string>{white_list_});
+  auto_mixed_precision_pass.Set("enable_low_precision_io",
+                                new bool{!keep_io_types_});
+  auto_mixed_precision_pass.Apply(main_graph_.get());
+
+  framework::ir::IdentityOpCleanPass identity_op_clean_pass;
+  identity_op_clean_pass.Apply(main_graph_.get());
 
   SaveMixedModel();
 }
@@ -184,9 +195,10 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& black_list) {
+                        const std::unordered_set<std::string>& black_list,
+                        const std::unordered_set<std::string>& white_list) {
   return framework::ir::OpSupportPrecision(
-      op_type, backend, precision, black_list);
+      op_type, backend, precision, black_list, white_list);
 }
 
 void InsertCastOp(
@@ -216,7 +228,8 @@ void ConvertToMixedPrecision(
     phi::DataType mixed_precision,
    phi::Backend backend,
     bool keep_io_types,
-    const std::unordered_set<std::string>& black_list) {
+    const std::unordered_set<std::string>& black_list,
+    const std::unordered_set<std::string>& white_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
@@ -224,7 +237,8 @@ void ConvertToMixedPrecision(
                                    mixed_precision,
                                    backend,
                                    keep_io_types,
-                                   black_list);
+                                   black_list,
+                                   white_list);
   pass.Run();
 }
 
diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
index 3a1e5fbb30a..c1809531e7d 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
@@ -38,7 +38,8 @@ class ConvertToMixedPrecisionPass {
     phi::DataType mixed_precision,
     phi::Backend backend,
     bool keep_io_types,
-      const std::unordered_set<std::string>& black_list);
+      const std::unordered_set<std::string>& black_list,
+      const std::unordered_set<std::string>& white_list);
 
   void Run();
 
@@ -55,6 +56,7 @@ class ConvertToMixedPrecisionPass {
   phi::Backend backend_;
   bool keep_io_types_;
   std::unordered_set<std::string> black_list_;
+  std::unordered_set<std::string> white_list_;
 
   framework::Scope scope_;
   std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
@@ -63,7 +65,8 @@ class ConvertToMixedPrecisionPass {
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& black_list);
+                        const std::unordered_set<std::string>& black_list,
+                        const std::unordered_set<std::string>& white_list);
 
 void InsertCastOp(
     framework::ir::Graph* graph,
@@ -82,7 +85,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
                              bool keep_io_types,
-                             const std::unordered_set<std::string>& black_list);
+                             const std::unordered_set<std::string>& black_list,
+                             const std::unordered_set<std::string>& white_list);
 
 }  // namespace analysis
 }  // namespace inference
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 702dd610fea..3f9ca0a58ed 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -448,6 +448,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
 
   // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
+  CP_MEMBER(mixed_white_list_);
   CP_MEMBER(enable_gpu_mixed_);
   CP_MEMBER(mixed_precision_mode_);
   CP_MEMBER(enable_low_precision_io_);
@@ -1154,6 +1155,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   for (auto attr : pattern) ss << attr;
   ss << ";";
   for (auto &op : mixed_black_list_) ss << op.c_str();
+  for (auto &op : mixed_white_list_) ss << op.c_str();
   return ss.str();
 }
 
@@ -1535,6 +1537,11 @@ void AnalysisConfig::Exp_DisableMixedPrecisionOps(
   mixed_black_list_ = black_list;
 }
 
+void AnalysisConfig::Exp_EnableMixedPrecisionOps(
+    const std::unordered_set<std::string> &white_list) {
+  mixed_white_list_ = white_list;
+}
+
 void AnalysisConfig::Exp_EnableCINNCompiler() {
 #ifdef PADDLE_WITH_CINN
   use_cinn_compiler_ = true;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 59374b2a392..e0a1343b934 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1616,6 +1616,7 @@ void AnalysisPredictor::PrepareArgument() {
   // mixed precison.
   argument_->SetModelPrecision(static_cast<int>(model_precision_));
   argument_->SetMixedBlackList(config_.mixed_black_list_);
+  argument_->SetMixedWhiteList(config_.mixed_white_list_);
   argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_);
   argument_->SetMixedPrecisionMode(static_cast<int>(
       paddle::ConvertPrecision(config_.mixed_precision_mode_)));
@@ -3097,7 +3098,8 @@ void ConvertToMixedPrecision(const std::string &model_file,
                              PrecisionType mixed_precision,
                              paddle_infer::PlaceType backend,
                              bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+                             std::unordered_set<std::string> black_list,
+                             std::unordered_set<std::string> white_list) {
   auto phi_backend = paddle::ConvertBackend(backend);
   auto phi_precision = paddle::ConvertPrecision(mixed_precision);
   paddle::inference::analysis::ConvertToMixedPrecision(model_file,
@@ -3107,7 +3109,8 @@ void ConvertToMixedPrecision(const std::string &model_file,
                                                        phi_precision,
                                                        phi_backend,
                                                        keep_io_types,
-                                                       black_list);
+                                                       black_list,
+                                                       white_list);
 }
 
 }  // namespace paddle_infer
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index def9c4f22e7..f1d193d0640 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -1147,6 +1147,14 @@ struct PD_INFER_DECL AnalysisConfig {
   void Exp_DisableMixedPrecisionOps(
       const std::unordered_set<std::string>& black_list);
 
+  ///
+  /// \brief Set a list of operators that do support mixed precision. This
+  /// interface is in the experimental stage and may change in the future. Note
+  /// that the whitelist must be the same as the model conversion whitelist.
+  ///
+  void Exp_EnableMixedPrecisionOps(
+      const std::unordered_set<std::string>& white_list);
+
   void SetApplyOptim(bool value) { apply_optim_ = value; }
 
   void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
@@ -1179,6 +1187,7 @@ struct PD_INFER_DECL AnalysisConfig {
   // Mixed precision related.
   Precision mixed_precision_mode_{Precision::kFloat32};
   std::unordered_set<std::string> mixed_black_list_;
+  std::unordered_set<std::string> mixed_white_list_;
   bool enable_low_precision_io_{false};
 
   // GPU related.
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 857e93e6f60..18b8b6dfd43 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -245,7 +245,8 @@ PD_INFER_DECL void ConvertToMixedPrecision(
     PrecisionType mixed_precision,
     PlaceType backend,
     bool keep_io_types = true,
-    std::unordered_set<std::string> black_list = {});
+    std::unordered_set<std::string> black_list = {},
+    std::unordered_set<std::string> white_list = {});
 
 namespace services {
 ///
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 09d1197d35b..31d044f8c0b 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -268,11 +268,11 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "identity_op_clean_pass",              //
         "conv2d_fusion_layout_transfer_pass",  //
         "transfer_layout_elim_pass",
         "auto_mixed_precision_pass",  //
-        "inplace_op_var_pass",        // should be the last pass.
+        "identity_op_clean_pass",  // should be after auto_mixed_precision_pass.
+        "inplace_op_var_pass",     // should be the last pass.
       });
 
   use_gpu_ = true;
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index e0a10df2bcc..16131ad12ea 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -544,7 +544,8 @@ void BindInferenceApi(py::module *m) {
         py::arg("mixed_precision"),
         py::arg("backend"),
         py::arg("keep_io_types") = true,
-        py::arg("black_list") = std::unordered_set<std::string>());
+        py::arg("black_list") = std::unordered_set<std::string>(),
+        py::arg("white_list") = std::unordered_set<std::string>());
 }
 
 namespace {
@@ -777,6 +778,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass)
       .def("exp_disable_mixed_precision_ops",
           &AnalysisConfig::Exp_DisableMixedPrecisionOps)
+      .def("exp_enable_mixed_precision_ops",
+          &AnalysisConfig::Exp_EnableMixedPrecisionOps)
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       .def("set_exec_stream",
           [](AnalysisConfig &self, phi::CUDAStream &stream) {
diff --git a/python/paddle/inference/wrapper.py b/python/paddle/inference/wrapper.py
index 7f922f3ee0d..7163bdee25e 100644
--- a/python/paddle/inference/wrapper.py
+++ b/python/paddle/inference/wrapper.py
@@ -78,7 +78,8 @@ def convert_to_mixed_precision(
     mixed_precision: PrecisionType,
     backend: PlaceType,
     keep_io_types: bool = True,
-    black_list: Set = set(),
+    black_list: Set[str] = set(),
+    **kwargs,
 ):
     '''
     Convert a fp32 model to mixed precision model.
@@ -92,6 +93,8 @@ def convert_to_mixed_precision(
         backend: The backend, e.g. PlaceType.GPU.
         keep_io_types: Whether the model input and output dtype remains unchanged.
         black_list: Operators that do not convert precision.
+        kwargs: Supported keys including 'white_list'.
+            - white_list: Operators that do convert precision.
     '''
     mixed_model_dirname = os.path.dirname(mixed_model_file)
     # Support mixed_params_file is empty, because some models don't have params, but convert_to_mixed_precision will call
@@ -104,6 +107,7 @@ def convert_to_mixed_precision(
     )
     if not os.path.exists(mixed_params_dirname):
         os.makedirs(mixed_params_dirname)
+    white_list = kwargs.get('white_list', set())
     convert_to_mixed_precision_bind(
         model_file,
         params_file,
@@ -113,6 +117,7 @@ def convert_to_mixed_precision(
         backend,
         keep_io_types,
         black_list,
+        white_list,
     )
diff --git a/test/ir/inference/test_identity_clean_pass.py b/test/ir/inference/test_identity_clean_pass.py
index dbab616410a..d484c2ced7f 100644
--- a/test/ir/inference/test_identity_clean_pass.py
+++ b/test/ir/inference/test_identity_clean_pass.py
@@ -155,5 +155,60 @@ class TestIdentityScaleCleanPass_V2(PassAutoScanTest):
         self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
 
 
+class TestIdentityCastCleanPass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_gpu=True)
+        yield config, ['relu', 'relu'], (1e-2, 1e-2)
+
+    def sample_program_config(self, draw):
+        n = draw(st.integers(min_value=1, max_value=4))
+        c = draw(st.integers(min_value=1, max_value=20))
+        h = draw(st.integers(min_value=1, max_value=20))
+        w = draw(st.integers(min_value=1, max_value=20))
+
+        relu_op_1 = OpConfig(
+            "relu",
+            inputs={"X": ["relu_op_1_in"]},
+            outputs={"Out": ["relu_op_1_out"]},
+        )
+        cast_op_1 = OpConfig(
+            "cast",
+            inputs={"X": ["relu_op_1_out"]},
+            outputs={"Out": ["cast_op_1_out"]},
+            in_dtype=5,
+            out_dtype=5,
+        )
+        relu_op_2 = OpConfig(
+            "relu",
+            inputs={"X": ["cast_op_1_out"]},
+            outputs={"Out": ["relu_op_2_out"]},
+        )
+        cast_op_2 = OpConfig(
+            "cast",
+            inputs={"X": ["relu_op_2_out"]},
["relu_op_2_out"]}, + outputs={"Out": ["cast_op_2_out"]}, + in_dtype=5, + out_dtype=4, + ) + cast_op_3 = OpConfig( + "cast", + inputs={"X": ["cast_op_2_out"]}, + outputs={"Out": ["cast_op_3_out"]}, + in_dtype=4, + out_dtype=5, + ) + + program_config = ProgramConfig( + ops=[relu_op_1, cast_op_1, relu_op_2, cast_op_2, cast_op_3], + weights={}, + inputs={"relu_op_1_in": TensorConfig(shape=[n, c, h, w])}, + outputs=["cast_op_3_out"], + ) + return program_config + + def test(self): + self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"]) + + if __name__ == "__main__": unittest.main() -- GitLab