From 53f1e024d4dcffe76d5b8b7a996eb7ed2ac802cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com>
Date: Sat, 2 Nov 2019 07:28:55 +0800
Subject: [PATCH] fix infer crashes caused by conv/pool upgrades,
 test=release/1.6 (#20969)

* fix infer crashes caused by conv/pool upgrades, test=release/1.6

* fix bug, test=release/1.6
---
 paddle/fluid/framework/op_compatible_info.cc  | 45 +++++++++++++++++++
 paddle/fluid/framework/op_compatible_info.h   |  5 +++
 .../fluid/inference/api/analysis_predictor.cc | 10 ++++-
 3 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/op_compatible_info.cc b/paddle/fluid/framework/op_compatible_info.cc
index 934f682811..d0702081d1 100644
--- a/paddle/fluid/framework/op_compatible_info.cc
+++ b/paddle/fluid/framework/op_compatible_info.cc
@@ -215,5 +215,50 @@ bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
   return true;
 }
 
+bool ProgOptimUnsupported(std::shared_ptr<ProgramDesc> program) {
+  auto op_type_checker = [](const std::string& name) {
+    const std::vector<std::string> op_types({
+        "conv2d", "conv3d", "conv2d_transpose", "conv3d_transpose",
+        "depthwise_conv2d", "depthwise_conv2d_transpose", "pool2d", "pool3d",
+    });
+    return std::find(op_types.begin(), op_types.end(), name) != op_types.end();
+  };
+  auto checker = [](const framework::OpDesc& op) {
+    if (op.HasAttr("paddings") && op.HasAttr("strides")) {
+      auto paddings = boost::get<std::vector<int>>(op.GetAttr("paddings"));
+      auto strides = boost::get<std::vector<int>>(op.GetAttr("strides"));
+      if (paddings.size() != strides.size()) {
+        VLOG(3) << "== paddings size is not equal to strides size.";
+        return true;
+      }
+    }
+    if (op.HasAttr("data_format")) {
+      auto data_format = boost::get<std::string>(op.GetAttr("data_format"));
+      if (data_format == "NHWC" || data_format == "NDHWC") {
+        VLOG(3) << "== data_format is NHWC or NDHWC.";
+        return true;
+      }
+    }
+    if (op.HasAttr("padding_algorithm")) {
+      auto padding_algorithm =
+          boost::get<std::string>(op.GetAttr("padding_algorithm"));
+      if (padding_algorithm != "EXPLICIT") {
+        VLOG(3) << "== padding_algorithm is not EXPLICIT.";
+        return true;
+      }
+    }
+    return false;
+  };
+  for (size_t i = 0; i < program->Size(); i++) {
+    const auto& block = program->Block(i);
+    for (auto* op : block.AllOps()) {
+      if ((op_type_checker(op->Type())) && checker(*op)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/op_compatible_info.h b/paddle/fluid/framework/op_compatible_info.h
index 08b5734b5b..e72d206b14 100644
--- a/paddle/fluid/framework/op_compatible_info.h
+++ b/paddle/fluid/framework/op_compatible_info.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <map>
+#include <memory>
 
 #include "paddle/fluid/framework/program_desc.h"
 
@@ -70,5 +71,9 @@ class OpCompatibleMap {
   std::string default_required_version_;
 };
 
+// Determine if the model contains operators that the optimization cannot
+// support.
+bool ProgOptimUnsupported(std::shared_ptr<ProgramDesc> program);
+
 }  // namespace framework
 }  // namespace paddle
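Note on the checks above: an operator is flagged when (1) "paddings" and
"strides" have different ranks, which is how asymmetric padding typically
manifests after the conv/pool attribute upgrade (two padding values per
spatial dimension instead of one), (2) the layout is channel-last
("NHWC"/"NDHWC"), or (3) "padding_algorithm" is anything other than
"EXPLICIT" (typically "SAME" or "VALID"). Below is a minimal standalone
sketch of the same rule; FakeOp is a hypothetical stand-in for
framework::OpDesc so the snippet compiles without Paddle:

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for framework::OpDesc, for illustration only.
    struct FakeOp {
      std::vector<int> paddings;
      std::vector<int> strides;
      std::string data_format;
      std::string padding_algorithm;
    };

    // Mirrors the three conditions checked by ProgOptimUnsupported().
    bool Unsupported(const FakeOp& op) {
      if (op.paddings.size() != op.strides.size()) return true;  // asymmetric
      if (op.data_format == "NHWC" || op.data_format == "NDHWC") return true;
      if (op.padding_algorithm != "EXPLICIT") return true;
      return false;
    }

    int main() {
      // conv2d with asymmetric padding: four padding values, two strides.
      FakeOp op{{1, 0, 1, 0}, {1, 1}, "NCHW", "EXPLICIT"};
      std::cout << (Unsupported(op) ? "skip IR optim" : "optimize") << "\n";
      return 0;
    }
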
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 7a62876972..4f632946f3 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -145,7 +145,7 @@ bool AnalysisPredictor::PrepareProgram(
     // still need to create other persistable variables.
     // So in both case, create persistable variables at first.
     if (!CheckOperatorCompatible()) {
-      LOG(WARNING) << "WARNING: Results may be DIFF! "
+      LOG(WARNING) << "WARNING: Results may be incorrect! "
                       "Using same versions between model and lib.";
     }
     executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
@@ -458,6 +458,14 @@ void AnalysisPredictor::PrepareArgument() {
 
 // NOTE All the members in AnalysisConfig should be copied to Argument.
 void AnalysisPredictor::OptimizeInferenceProgram() {
+  if (ProgOptimUnsupported(inference_program_)) {
+    LOG(INFO) << "NOTICE: Your inference model contains attributes such "
+                 "as asymmetric padding, which IR optimization does not "
+                 "yet support, "
+                 "so IR optimization has been turned off.";
+    config_.SwitchIrOptim(false);
+    argument_.SetEnableAnalysisOptim(false);
+  }
   PrepareArgument();
   Analyzer().Run(&argument_);
 
--
GitLab
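Usage note: with this patch the predictor turns IR optimization off by
itself when such a model is detected; a deployer can also disable it up
front. A minimal sketch against the release/1.6 C++ inference API (the
model path is a placeholder, and error handling is omitted):

    #include <memory>

    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    int main() {
      paddle::AnalysisConfig config;
      config.SetModel("./model_dir");  // placeholder model directory
      // The public form of the switch the patch flips internally
      // via config_.SwitchIrOptim(false).
      config.SwitchIrOptim(false);
      auto predictor =
          paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(config);
      return predictor != nullptr ? 0 : 1;
    }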