From 490e746269b11fefac4e23f7fad2a096f94b7c31 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 23 Apr 2019 16:55:19 +0800 Subject: [PATCH] fix runtime_context_cache bug when gpu model has an op runs only on cpu test=develop --- paddle/fluid/framework/operator.cc | 11 +++++ .../inference/api/paddle_pass_builder.cc | 42 ++++++++++--------- .../tests/api/analyzer_pyramid_dnn_tester.cc | 5 --- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 1723a9a78..78410c0d0 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1095,6 +1095,17 @@ Scope* OperatorWithKernel::PrepareData( if (!new_scope) { new_scope = &scope.NewScope(); } + // For inference, if a gpu model has an op which could only run on CPU, + // each result of different input will be the same with the first one. + // The reason is that if a gpu tensor is the input of a cpu kernel, + // we will create a new cpu tensor in new scope. + // However, if enable_cache_runtime_context, we get the cpu tensor each + // time, not the gpu tensor. + // Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()` in + // RunImpl(). + if (enable_cache_runtime_context) { + pre_scope_ = nullptr; + } auto* trans_var = new_scope->Var(var_name); input_vars[i] = trans_var; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 30966772e..2fba560ac 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -98,7 +98,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add_fuse_pass", // #endif // "transpose_flatten_concat_fuse_pass", + // following two passes should be located in the last, since they will + // work on all fused ops. "expected_kernel_cache_pass", // + "runtime_context_cache_pass" }); use_gpu_ = true; @@ -115,25 +118,26 @@ void GpuPassStrategy::EnableMkldnnQuantizer() { CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { // NOTE the large fusions should be located in the front, so that they will // not be damaged by smaller ones. - passes_.assign({ - "infer_clean_graph_pass", // - "attention_lstm_fuse_pass", // - "seqconv_eltadd_relu_fuse_pass", // - // "seqpool_concat_fuse_pass", // - // "embedding_fc_lstm_fuse_pass", // - "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "fc_fuse_pass", // - "repeated_fc_relu_fuse_pass", // - "squared_mat_sub_fuse_pass", // - "conv_bn_fuse_pass", // - "conv_eltwiseadd_bn_fuse_pass", // - "is_test_pass", // - "expected_kernel_cache_pass", // - }); + passes_.assign({"infer_clean_graph_pass", // + "attention_lstm_fuse_pass", // + "seqconv_eltadd_relu_fuse_pass", // + // "seqpool_concat_fuse_pass", // + // "embedding_fc_lstm_fuse_pass", // + "fc_lstm_fuse_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "fc_fuse_pass", // + "repeated_fc_relu_fuse_pass", // + "squared_mat_sub_fuse_pass", // + "conv_bn_fuse_pass", // + "conv_eltwiseadd_bn_fuse_pass", // + "is_test_pass", // + // following two passes should be located in the last, since + // they will work on all fused ops. + "expected_kernel_cache_pass", // + "runtime_context_cache_pass"}); use_gpu_ = false; } diff --git a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc index 1bb06aa21..9443b0806 100644 --- a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc @@ -110,11 +110,6 @@ void SetConfig(AnalysisConfig *cfg) { if (FLAGS_zero_copy) { cfg->SwitchUseFeedFetchOps(false); } - // Enable runtime_context_cache_pass, disabled by default since it doesn't - // cover all the cases. - // See detail: https://github.com/PaddlePaddle/Paddle/issues/16609 - // https://github.com/PaddlePaddle/Paddle/issues/16841 - cfg->pass_builder()->AppendPass("runtime_context_cache_pass"); } void SetInput(std::vector> *inputs) { -- GitLab