diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 9dd1dacc02c25474803ef3177d9cd967ee681714..2317bfdd7c0d5ee94e91e081da47177625f5bfd8 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -186,11 +186,10 @@ PreparedOp PrepareImpl(const NameVarMap& ins, << " | kernel key: " << pt_kernel_key << " | kernel: " << pt_kernel; - if (platform::is_cpu_place(expected_kernel_key.place_)) { - auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); - return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, - pt_kernel, cpu_ctx); + if (expected_kernel_key.place_ != place) { + dev_ctx = pool.Get(expected_kernel_key.place_); } + // TODO(chenweihang): using CPUKernel when miss device kernel case return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, pt_kernel, dev_ctx);