未验证 提交 dd181238 编写于 作者: W wenbin 提交者: GitHub

fix inference prepare data bug (#33305)

* fix inference prepare data bug

* rename functions

* typo

* typo

* typo

* UT correct

* correct condition

* correct condition

* ci coverage

* morelines

* fix ci coverage
上级 6877b134
...@@ -1531,7 +1531,12 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1531,7 +1531,12 @@ Scope* OperatorWithKernel::PrepareData(
// the rest iterations to save the elapsed time. // the rest iterations to save the elapsed time.
// We do not support skipping PrepareData in while block, because the Op's // We do not support skipping PrepareData in while block, because the Op's
// input may be changed by subsequent Ops, which may cause an error. // input may be changed by subsequent Ops, which may cause an error.
if (pre_scope_ == &scope && new_scope == nullptr) {
// For inference, ops that behind conditional branch aren't supported well,
// so disable prepare optimization conservatively.
bool force_prepare_data = HasAttr("inference_force_prepare_data") &&
Attr<bool>("inference_force_prepare_data");
if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) {
need_prepare_data_ = false; need_prepare_data_ = false;
} }
......
...@@ -270,7 +270,46 @@ bool AnalysisPredictor::CreateExecutor() { ...@@ -270,7 +270,46 @@ bool AnalysisPredictor::CreateExecutor() {
executor_.reset(new paddle::framework::NaiveExecutor(place_)); executor_.reset(new paddle::framework::NaiveExecutor(place_));
return true; return true;
} }
static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
// here is prepare data optimization related bad cases:
// let's assume an op behind conditional_block and if conditional_block
// chooses branch 1, the op need to call prepare data. else the op don't need
// to call prepare data. In running, if predictor chooses branch 2, then
// optimization takes effect, later issue is followed if predictor chooses
// branch 1, because the op lost chance to prepare data.
std::vector<std::string> op_type = {"conditional_block_infer",
"select_input"};
for (const auto &type : op_type) {
if (op->Type() == type) {
return true;
}
}
return false;
}
static void DisablePrepareDataOpt(
std::shared_ptr<framework::ProgramDesc> inference_program, int block,
bool pre_disable_opt) {
bool disable_opt = false;
auto &infer_block = inference_program->Block(block);
for (auto *op : infer_block.AllOps()) {
if (disable_opt || pre_disable_opt) {
op->SetAttr("inference_force_prepare_data", true);
}
if (op->HasAttr("sub_block")) {
int blockID = op->GetBlockAttrId("sub_block");
DisablePrepareDataOpt(inference_program, blockID,
disable_opt || pre_disable_opt);
}
// disable prepare data if unfriendly op is found
disable_opt = IsPrepareDataOptTargetOp(op);
}
}
bool AnalysisPredictor::PrepareExecutor() { bool AnalysisPredictor::PrepareExecutor() {
DisablePrepareDataOpt(inference_program_, 0, false);
executor_->Prepare(sub_scope_, *inference_program_, 0, executor_->Prepare(sub_scope_, *inference_program_, 0,
config_.use_feed_fetch_ops_); config_.use_feed_fetch_ops_);
......
...@@ -73,6 +73,8 @@ class ConditionalBlockInferOp : public ConditionalOp { ...@@ -73,6 +73,8 @@ class ConditionalBlockInferOp : public ConditionalOp {
framework::Executor exec(dev_place); framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block"); auto *block = Attr<framework::BlockDesc *>("sub_block");
VLOG(3) << "Conditional block.idx = " << block->ID()
<< ", scope = " << &cur_scope;
exec.Run(*block->Program(), &cur_scope, block->ID(), false); exec.Run(*block->Program(), &cur_scope, block->ID(), false);
scope.DeleteScope(scopes->front()); scope.DeleteScope(scopes->front());
} }
......
...@@ -219,7 +219,7 @@ class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest): ...@@ -219,7 +219,7 @@ class DynamicShapeTensorRTSubgraphPassConvTest(InferencePassTest):
}, { }, {
"conv2d_0.tmp_0": [16, 6, 16, 16], "conv2d_0.tmp_0": [16, 6, 16, 16],
"data": [16, 6, 16, 16], "data": [16, 6, 16, 16],
"depthwise_conv2d_0.tmp_0": [32, 6, 64, 64] "depthwise_conv2d_0.tmp_0": [16, 6, 16, 16]
}, False) }, False)
self.fetch_list = [conv_out] self.fetch_list = [conv_out]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册