From 4a4ffe9abb163ce9746469453b214106e3bf01ed Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Thu, 18 May 2023 23:13:18 +0800 Subject: [PATCH] [XPU] fix bug on XPUPlace and AllGather (#53926) --- paddle/fluid/distributed/collective/process_group_bkcl.cc | 2 +- python/paddle/__init__.py | 2 +- python/paddle/framework/__init__.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 99323d9d45a..095a04d2243 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -381,7 +381,7 @@ std::shared_ptr ProcessGroupBKCL::AllGather( phi::AllocationType::XPU); return Collective( out_tensor, - in_tensor, + in_tensor_maybe_partial, [&](phi::DenseTensor* output, const phi::DenseTensor& input, BKCLContext_t comm, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index f878a6b8d7f..ff5f4c86507 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -349,6 +349,7 @@ from .framework import IPUPlace # noqa: F401 from .framework import CUDAPlace # noqa: F401 from .framework import CUDAPinnedPlace # noqa: F401 from .framework import CustomPlace # noqa: F401 +from .framework import XPUPlace # noqa: F401 from .autograd import grad # noqa: F401 from .autograd import no_grad # noqa: F401 @@ -380,7 +381,6 @@ from .device import is_compiled_with_cinn # noqa: F401 from .device import is_compiled_with_cuda # noqa: F401 from .device import is_compiled_with_rocm # noqa: F401 from .device import is_compiled_with_custom_device # noqa: F401 -from .device import XPUPlace # noqa: F401 # high-level api from .hapi import Model # noqa: F401 diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 66456d48253..0ae39627d9a 100755 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -25,6 +25,7 @@ from ..fluid.core import IPUPlace # noqa: F401 from ..fluid.core import CUDAPlace # noqa: F401 from ..fluid.core import CUDAPinnedPlace # noqa: F401 from ..fluid.core import CustomPlace # noqa: F401 +from ..fluid.core import XPUPlace # noqa: F401 from ..fluid import core # noqa: F401 from ..fluid.dygraph import base, to_variable -- GitLab