未验证 提交 d9cd9898 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #17048 from luotao1/fix_runtime_cache_bug

fix runtime_context_cache bug when a GPU model has an op that runs only on CPU
......@@ -1095,6 +1095,17 @@ Scope* OperatorWithKernel::PrepareData(
if (!new_scope) {
new_scope = &scope.NewScope();
}
// For inference: if a GPU model contains an op that can only run on CPU,
// the result for every input would wrongly be identical to that of the
// first input.
// The reason is that when a GPU tensor is the input of a CPU kernel,
// we create a new CPU tensor in a new scope.
// However, if enable_cache_runtime_context is set, we would keep reusing
// the cached CPU tensor instead of reading the fresh GPU tensor.
// Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()` in
// RunImpl().
if (enable_cache_runtime_context) {
pre_scope_ = nullptr;
}
auto* trans_var = new_scope->Var(var_name);
input_vars[i] = trans_var;
......
......@@ -98,7 +98,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass",
// The following two passes should be located at the end, since they
// work on all fused ops.
"expected_kernel_cache_pass", //
"runtime_context_cache_pass"
});
use_gpu_ = true;
......@@ -115,8 +118,7 @@ void GpuPassStrategy::EnableMkldnnQuantizer() {
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones.
passes_.assign({
"infer_clean_graph_pass", //
passes_.assign({"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
// "seqpool_concat_fuse_pass", //
......@@ -132,8 +134,10 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
// The following two passes should be located at the end, since
// they work on all fused ops.
"expected_kernel_cache_pass", //
});
"runtime_context_cache_pass"});
use_gpu_ = false;
}
......
......@@ -110,11 +110,6 @@ void SetConfig(AnalysisConfig *cfg) {
if (FLAGS_zero_copy) {
cfg->SwitchUseFeedFetchOps(false);
}
// Enable runtime_context_cache_pass, disabled by default since it doesn't
// cover all the cases.
// See detail: https://github.com/PaddlePaddle/Paddle/issues/16609
// https://github.com/PaddlePaddle/Paddle/issues/16841
cfg->pass_builder()->AppendPass("runtime_context_cache_pass");
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册