diff --git a/paddle/phi/api/lib/tensor_copy.cc b/paddle/phi/api/lib/tensor_copy.cc index 57e3c28d8cb1f9f5db19170084d93b1ca922508e..0fbdee257612c2c3ba3f6841eed51a5f5bab746f 100644 --- a/paddle/phi/api/lib/tensor_copy.cc +++ b/paddle/phi/api/lib/tensor_copy.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/api/lib/tensor_copy.h" + +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" @@ -24,18 +26,21 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) { +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { auto kernel_key_set = ParseKernelKeyByInputArgs(src); kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place)); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy", kernel_key); VLOG(6) << "copy API kernel key: " << kernel_key; + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "copy", kernel_key); VLOG(6) << "copy API kernel: " << kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto target_place = phi::TransToPhiPlace(kernel_key.backend()); + auto& pool = paddle::experimental::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetMutable( + target_place.GetType() == place.GetType() ? place : target_place); auto dense_x = TensorToDenseTensor(src); diff --git a/paddle/phi/api/lib/tensor_copy.h b/paddle/phi/api/lib/tensor_copy.h index 3ce45853319ecf24b21a1305288bdd441f1c1e1c..4a50b78be85d0862e0aa85786f9e09919d00518f 100644 --- a/paddle/phi/api/lib/tensor_copy.h +++ b/paddle/phi/api/lib/tensor_copy.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst); +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst); } // namespace experimental } // namespace paddle