diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 99323d9d45ad58e2de8a96b48b44be5d510299fb..095a04d2243ab18099ba58564b0e9dd5cfc9031b 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -381,7 +381,7 @@ std::shared_ptr ProcessGroupBKCL::AllGather( phi::AllocationType::XPU); return Collective( out_tensor, - in_tensor, + in_tensor_maybe_partial, [&](phi::DenseTensor* output, const phi::DenseTensor& input, BKCLContext_t comm, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index f878a6b8d7f7eaf2847071ad4ecd8e440cfd92dc..ff5f4c865079583d1f0f2e9c8010cbbe3aecf7d2 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -349,6 +349,7 @@ from .framework import IPUPlace # noqa: F401 from .framework import CUDAPlace # noqa: F401 from .framework import CUDAPinnedPlace # noqa: F401 from .framework import CustomPlace # noqa: F401 +from .framework import XPUPlace # noqa: F401 from .autograd import grad # noqa: F401 from .autograd import no_grad # noqa: F401 @@ -380,7 +381,6 @@ from .device import is_compiled_with_cinn # noqa: F401 from .device import is_compiled_with_cuda # noqa: F401 from .device import is_compiled_with_rocm # noqa: F401 from .device import is_compiled_with_custom_device # noqa: F401 -from .device import XPUPlace # noqa: F401 # high-level api from .hapi import Model # noqa: F401 diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 66456d48253f8eb948ef1298b30ddbe70813978c..0ae39627d9a5d8c13024e6eba7add32a85a5377a 100755 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -25,6 +25,7 @@ from ..fluid.core import IPUPlace # noqa: F401 from ..fluid.core import CUDAPlace # noqa: F401 from ..fluid.core import CUDAPinnedPlace # noqa: F401 from ..fluid.core import CustomPlace # noqa: F401 +from ..fluid.core import XPUPlace # noqa: F401 from ..fluid import core # noqa: F401 from ..fluid.dygraph import base, to_variable