From 91212104562f7075c01bce2c60e9a81b804b77e2 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 5 Apr 2022 22:41:15 +0800 Subject: [PATCH] Fix bug of data transform in inference executor (#41349) * fix bug of data transform in inference executor * fix bug --- paddle/fluid/framework/operator.cc | 10 ++++++++++ paddle/phi/kernels/gpu/arange_kernel.cu | 6 +++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 49248edd322..6af07caaf88 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2176,6 +2176,16 @@ Scope* OperatorWithKernel::PreparePhiData( if (!new_scope) { new_scope = &scope.NewScope(); } + // For inference, if a gpu model has an op which could only run on CPU, + // each result of different input will be the same with the first one. + // The reason is that if a gpu tensor is the input of a cpu kernel, + // we will create a new cpu tensor in new scope. + // However, if enable_cache_runtime_context_, we get the cpu tensor each + // time, not the gpu tensor. Thus, we set pre_scope_ = nullptr + // to trigger `new RuntimeContext()` in RunImpl(). + if (enable_cache_runtime_context_) { + pre_scope_ = nullptr; + } // Create new var with the same name in transfer scopes auto* trans_var = new_scope->Var(name_vec[offset]); diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index 916f6aa5537..9ea0d7c5393 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -64,7 +64,7 @@ void ArangeKernel(const Context& dev_ctx, PD_REGISTER_KERNEL( arange, GPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int64_t, int) { - kernel->InputAt(0).SetBackend(phi::Backend::CPU); - kernel->InputAt(1).SetBackend(phi::Backend::CPU); - kernel->InputAt(2).SetBackend(phi::Backend::CPU); + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); } -- GitLab