提交 53f1e024 编写于 作者: 石晓伟 提交者: kolinwei

fix infer crashes caused by conv/pool upgrades, test=release/1.6 (#20969)

* fix infer crashes caused by conv/pool upgrades, test=release/1.6

* fix bug, test=release/1.6
上级 6f0b2b19
...@@ -215,5 +215,50 @@ bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) { ...@@ -215,5 +215,50 @@ bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
return true; return true;
} }
bool ProgOptimUnsupported(std::shared_ptr<framework::ProgramDesc> program) {
auto op_type_checker = [](const std::string& name) {
const std::vector<std::string> op_types({
"conv2d", "conv3d", "conv2d_transpose", "conv3d_transpose",
"depthwise_conv2d", "depthwise_conv2d_transpose", "pool2d", "pool3d",
});
return std::find(op_types.begin(), op_types.end(), name) != op_types.end();
};
auto checker = [](const framework::OpDesc& op) {
if (op.HasAttr("paddings") && op.HasAttr("strides")) {
auto paddings = boost::get<std::vector<int>>(op.GetAttr("paddings"));
auto strides = boost::get<std::vector<int>>(op.GetAttr("strides"));
if (paddings.size() != strides.size()) {
VLOG(3) << "== paddings size is not equal to strides size.";
return true;
}
}
if (op.HasAttr("data_format")) {
auto data_format = boost::get<std::string>(op.GetAttr("data_format"));
if (data_format == "NHWC" || data_format == "NDHWC") {
VLOG(3) << "== data_format is NHWC or NDHWC.";
return true;
}
}
if (op.HasAttr("padding_algorithm")) {
auto padding_algorithm =
boost::get<std::string>(op.GetAttr("padding_algorithm"));
if (padding_algorithm != "EXPLICIT") {
VLOG(3) << "== padding_algorithm is not EXPLICIT.";
return true;
}
}
return false;
};
for (size_t i = 0; i < program->Size(); i++) {
const auto& block = program->Block(i);
for (auto* op : block.AllOps()) {
if ((op_type_checker(op->Type())) && checker(*op)) {
return true;
}
}
}
return false;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -70,5 +71,9 @@ class OpCompatibleMap { ...@@ -70,5 +71,9 @@ class OpCompatibleMap {
std::string default_required_version_; std::string default_required_version_;
}; };
// Determine if the model contains operators that the optimization cannot
// support.
bool ProgOptimUnsupported(std::shared_ptr<framework::ProgramDesc> program);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -145,7 +145,7 @@ bool AnalysisPredictor::PrepareProgram( ...@@ -145,7 +145,7 @@ bool AnalysisPredictor::PrepareProgram(
// still need to create other persistable variables. // still need to create other persistable variables.
// So in both case, create persistable variables at first. // So in both case, create persistable variables at first.
if (!CheckOperatorCompatible()) { if (!CheckOperatorCompatible()) {
LOG(WARNING) << "WARNING: Results may be DIFF! " LOG(WARNING) << "WARNING: Results may be incorrect! "
"Using same versions between model and lib."; "Using same versions between model and lib.";
} }
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_); executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
...@@ -458,6 +458,14 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -458,6 +458,14 @@ void AnalysisPredictor::PrepareArgument() {
// NOTE All the members in AnalysisConfig should be copied to Argument. // NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() { void AnalysisPredictor::OptimizeInferenceProgram() {
if (ProgOptimUnsupported(inference_program_)) {
LOG(INFO) << "NOTICE: Your inference model contains parameters such "
"as asymmetric padding, and ir optimization is temporarily "
"not supported, "
"so it is turned off.";
config_.SwitchIrOptim(false);
argument_.SetEnableAnalysisOptim(false);
}
PrepareArgument(); PrepareArgument();
Analyzer().Run(&argument_); Analyzer().Run(&argument_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册