diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index 7acbe227b3296eb4add60abb94c6998294cdd298..cdc6520ad6f4454fd06779f55e8d07bb3ca2b032 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -90,17 +90,17 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
     desc.SetAttr("activation_type", activation_type);
 
     // This is to add padding for dimension 128 on concern of MKL performance
-    auto* scope = param_scope();
-    auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
-    auto* weight_data = weight->data<float>();
-    auto weight_dims = weight->dims();
-    int weight_num = product(weight_dims);
-    int w_h = weight_dims[0];
-    int w_w = weight_dims[1];
     bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false;
     bool use_fc_padding =
         Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true;
     if (!use_gpu && use_fc_padding) {
+      auto* scope = param_scope();
+      auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
+      auto* weight_data = weight->data<float>();
+      auto weight_dims = weight->dims();
+      int weight_num = product(weight_dims);
+      int w_h = weight_dims[0];
+      int w_w = weight_dims[1];
       if (w_h % 128 == 0 && w_w % 128 == 0) {
         auto* weight_data_tmp = new float[weight_num];
         for (int i = 0; i < w_h; i++) {
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 630e375dd7e8e85767e0146cdceba08a3ab3724e..1cf72f7001b2a56eb613340f8d3d71c1bbec03a6 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -172,6 +172,7 @@ struct Argument {
 
   // Passed from config.
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
+  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
   DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 65189b89375da755be5279e27e5dc71d27370c41..a4a2fdb2b687ff6c354a6fdae2221e5947766e6d 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -142,6 +142,14 @@ void IRPassManager::CreatePasses(Argument *argument,
     disable_logs_ = argument->disable_logs();
     if (pass_name == "fc_fuse_pass") {
       pass->Set("use_gpu", new bool(argument->use_gpu()));
+      bool fc_mkldnn_pass = 0;
+      for (const std::string &pass_n : passes) {
+        if (pass_n == "fc_mkldnn_pass") {
+          fc_mkldnn_pass = 1;
+        }
+      }
+      bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
+      pass->Set("use_fc_padding", new bool(use_fc_padding));
     }
 
     pre_pass = pass_name;
@@ -150,47 +158,12 @@ void IRPassManager::CreatePasses(Argument *argument,
   }
 }
 
-bool IRPassManager::HasPass(const std::string &pass_type) {
-  if (passes_.empty()) return false;
-  auto it = std::find_if(
-      passes_.begin(), passes_.end(),
-      [&](std::unique_ptr<Pass> &pass) { return pass->Type() == pass_type; });
-  return it != passes_.end();
-}
-
-std::unique_ptr<Pass> &IRPassManager::GetPass(const std::string &pass_type) {
-  PADDLE_ENFORCE_EQ(passes_.empty(), false,
-                    platform::errors::PreconditionNotMet(
-                        "The list of passes cannot be empty."));
-  auto it = std::find_if(passes_.begin(), passes_.end(),
-                         [&](const std::unique_ptr<Pass> &pass) {
-                           return pass->Type() == pass_type;
-                         });
-  PADDLE_ENFORCE_NE(it, passes_.end(),
-                    platform::errors::PermissionDenied(
-                        "You cannot get pass which was not added earlier."));
-  return *it;
-}
-
-// Some passes depend on each other. This method serves for exchanging
-// information between them.
-void IRPassManager::UpdatePasses() {
-  // Update padding settings for fc_fuse_pass. Skip adding padding for
-  // MKL-DNN-based FC
-  bool use_fc_padding = !HasPass("fc_mkldnn_pass");
-  if (HasPass("fc_fuse_pass")) {
-    auto &fc_fuse_pass = GetPass("fc_fuse_pass");
-    fc_fuse_pass->Set("use_fc_padding", new bool(use_fc_padding));
-  }
-}
-
 std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
   if (passes_.empty()) {
     return graph;
   }
   PADDLE_ENFORCE_NOT_NULL(graph.get(), platform::errors::PreconditionNotMet(
                                            "Graph cannot be NULL."));
-  UpdatePasses();
   // Apply all the passes
   for (const auto &pass : passes_) {
     if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.h b/paddle/fluid/inference/analysis/ir_pass_manager.h
index 2366a0eaf094b68885399a061fdabf8f5eaba211..823dc8907ea532b510e8d643c361b98c863404e3 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.h
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.h
@@ -54,9 +54,6 @@ class IRPassManager final {
  private:
   void CreatePasses(Argument *argument,
                     const std::vector<std::string> &passes);
-  bool HasPass(const std::string &pass_type);
-  std::unique_ptr<Pass> &GetPass(const std::string &pass_type);
-  void UpdatePasses();
 
   std::unique_ptr<Graph> graph_;
   std::vector<std::unique_ptr<Pass>> passes_;
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 3755d088bfe97427eaf6b120eacae03937cff7b9..75a05fa309fcb29cbf0d89294100366471a724c7 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -82,6 +82,12 @@ void AnalysisConfig::DisableGpu() {
   Update();
 }
 
+void AnalysisConfig::DisableFCPadding() {
+  use_fc_padding_ = false;
+
+  Update();
+}
+
 AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
 #define CP_MEMBER(member__) member__ = other.member__;
 
@@ -94,6 +100,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   prog_file_ = std::move(other.prog_file_);
   params_file_ = std::move(other.params_file_);
 
+  CP_MEMBER(use_fc_padding_);
   // GPU related.
   CP_MEMBER(use_gpu_);
   CP_MEMBER(use_cudnn_);
@@ -354,6 +361,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
 
   ss << use_gpu_;
+  ss << use_fc_padding_;
   ss << device_id_;
   ss << memory_pool_init_size_mb_;
 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index cd12a8985ba9b467b0d649a74dd351b2420944c4..107e5ae7d818d24d664e4a02530b718f9460364e 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -381,6 +381,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
 
 void AnalysisPredictor::PrepareArgument() {
   argument_.SetUseGPU(config_.use_gpu());
+  argument_.SetUseFcPadding(config_.use_fc_padding());
   argument_.SetGPUDeviceId(config_.gpu_device_id());
   argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
   argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 292615a019ad648c3fcf5446453b99552451603f..260ec6562aa3521a751782c065c6c15ada774fa7 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -77,6 +77,14 @@ struct AnalysisConfig {
    */
   const std::string& params_file() const { return params_file_; }
 
+  // Padding related.
+  /** Turn off FC padding.
+   */
+  void DisableFCPadding();
+  /** A bool state telling whether FC padding is turned on.
+   */
+  bool use_fc_padding() const { return use_fc_padding_; }
+
   // GPU related.
 
   /**
@@ -293,6 +301,9 @@ struct AnalysisConfig {
 
   bool use_cudnn_{false};
 
+  // Padding related.
+  bool use_fc_padding_{true};
+
   // TensorRT related.
   bool use_tensorrt_{false};
   // For workspace_size, refer it from here:
diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
index 5035f9b358718c4b3da445f82863c5d66e2dfbe6..7232dbbe57069a7e8d780047aca090d8ef970505 100644
--- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
@@ -145,7 +145,10 @@ bool LoadInputData(std::vector<std::vector<PaddleTensor>> *inputs) {
   return true;
 }
 
-void SetConfig(AnalysisConfig *config) { config->SetModel(FLAGS_infer_model); }
+void SetConfig(AnalysisConfig *config) {
+  config->SetModel(FLAGS_infer_model);
+  config->DisableFCPadding();
+}
 
 void profile(bool use_mkldnn = false, bool use_ngraph = false) {
   AnalysisConfig config;
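
For reference, a minimal sketch (not part of this patch) of how the new AnalysisConfig::DisableFCPadding() API introduced above might be used from application code. The model directory and the CreatePaddlePredictor entry point are illustrative assumptions about the C++ inference API of this period, not something this diff adds.

#include <memory>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  // Hypothetical model directory, for illustration only.
  config.SetModel("./bert_model");

  // FC padding defaults to on (use_fc_padding_{true}); turn it off explicitly.
  // Independently of this flag, fc_fuse_pass also skips padding when running
  // on GPU or when fc_mkldnn_pass is in the pass list.
  config.DisableFCPadding();
  bool padding_on = config.use_fc_padding();  // false after the call above
  (void)padding_on;

  // Assumed predictor creation call; padded or unpadded FC weights are
  // prepared by the analysis passes while the predictor is being built.
  auto predictor = paddle::CreatePaddlePredictor(config);
  (void)predictor;
  return 0;
}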