# Patch summary (commentary only; these lines sit above the first "diff --git"
# header and are not part of any hunk).
#
# - mkldnn_placement_pass.h: adds one #include and a block comment describing
#   MKLDNNPlacementPass.
# - argument.h: declares two new Argument fields, quantize_enabled_op_types and
#   quantize_excluded_op_ids.
# - ir_pass_manager.cc: IRPassManager::CreatePasses gains a branch for
#   "cpu_quantize_placement_pass" that sets the pass attributes
#   "quantize_enabled_op_types" and "quantize_excluded_op_ids" from the
#   corresponding Argument accessors.
# - paddle_pass_builder.h/.cc: adds virtual PassStrategy::EnableQuantizer() with
#   an empty default body. The CpuPassStrategy override appends
#   "cpu_quantize_placement_pass" to passes_ only while use_quantizer_ is false,
#   then sets use_quantizer_ to true. The GpuPassStrategy override only logs an
#   error message.
#
# NOTE(review): in this copy the text inside angle brackets (the header names
# after "#include", the template arguments of std::unordered_set and
# std::unique_ptr) is missing -- presumably stripped during extraction. It is
# left as found; restore it from the original commit before applying.
# NOTE(review): line breaks and blank context lines were placed so that every
# hunk matches its "@@ -a,b +c,d @@" counts; indentation widths are assumed
# (2-space style) and should be confirmed against the target files.
diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h
index 3d4dc9e2b6ecccddea4d63e45710c80d55ef2772..c071d9aed20bd40f5c1076d2dc5d3098a4e65495 100644
--- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h
@@ -14,12 +14,16 @@ limitations under the License. */
 
 #pragma once
 
+#include
 #include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
+/*
+ * Specifies which operators should use MKLDNN.
+ */
 class MKLDNNPlacementPass : public Pass {
  protected:
   std::unique_ptr ApplyImpl(
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 321deccf86718aad013c106b5a783161f96cbcb9..997f3575f457b67d4df5000705724b46cd8b951d 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -131,6 +131,15 @@ struct Argument {
   // Pass a set of op types to enable its mkldnn kernel
   DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
                       std::unordered_set);
+
+  // A set of op types to enable their quantized kernels
+  DECL_ARGUMENT_FIELD(quantize_enabled_op_types, QuantizeEnabledOpTypes,
+                      std::unordered_set);
+
+  // A set of op IDs to exclude from enabling their quantized kernels
+  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids, QuantizeExcludedOpIds,
+                      std::unordered_set);
+
   // Scales for variables to be quantized
   DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
 
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 8fd86b2cc56c4af50e735be2d660ec3db23e1547..1556caa46412c8a2dacd44f2187666c6a1fda6bf 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include
 #include
+#include
 #include
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -60,6 +61,13 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("mkldnn_enabled_op_types",
                 new std::unordered_set(
                     argument->mkldnn_enabled_op_types()));
+    } else if (pass_name == "cpu_quantize_placement_pass") {
+      pass->Set("quantize_enabled_op_types",
+                new std::unordered_set(
+                    argument->quantize_enabled_op_types()));
+      pass->Set(
+          "quantize_excluded_op_ids",
+          new std::unordered_set(argument->quantize_excluded_op_ids()));
     } else if (pass_name == "cpu_quantize_pass") {
       pass->Set("quant_var_scales",
                 new VarQuantScale(argument->quant_var_scales()));
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 22c527cfc117a5e6ababf264744745e41e0bf71a..d413a418c88241a15808474f753a3900e0a5293e 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -91,6 +91,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   use_gpu_ = true;
 }
 
+void GpuPassStrategy::EnableQuantizer() {
+  LOG(ERROR) << "GPU not support quantization yet";
+}
+
 void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
   analysis_passes_.push_back(pass);
 }
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 2524d89fcd1322e105ad2217347aa2380448f2bc..84645fef018ce41ee2cba7ae25d2b0c13e49dfc0 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -84,6 +84,10 @@ class PassStrategy : public PaddlePassBuilder {
    */
   virtual void EnableMKLDNN() {}
 
+  /** Enable quantize optimization
+   */
+  virtual void EnableQuantizer() {}
+
   bool use_gpu() const { return use_gpu_; }
 
   virtual ~PassStrategy() = default;
@@ -124,6 +128,16 @@ class CpuPassStrategy : public PassStrategy {
     use_mkldnn_ = false;
 #endif
   }
+
+  void EnableQuantizer() override {
+    if (!use_quantizer_) {
+      passes_.push_back("cpu_quantize_placement_pass");
+    }
+    use_quantizer_ = true;
+  }
+
+ protected:
+  bool use_quantizer_{false};
 };
 
 /** The GPU passes strategy, it is used in AnalysisPredictor with GPU mode.
@@ -138,6 +152,7 @@ class GpuPassStrategy : public PassStrategy {
   }
 
   void EnableMKLDNN() override;
+  void EnableQuantizer() override;
 
   virtual ~GpuPassStrategy() = default;
 };