未验证 提交 4a4ffe9a 编写于 作者: H houj04 提交者: GitHub

[XPU] fix bug on XPUPlace and AllGather (#53926)

上级 3ad67b9a
...@@ -381,7 +381,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather( ...@@ -381,7 +381,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
phi::AllocationType::XPU); phi::AllocationType::XPU);
return Collective( return Collective(
out_tensor, out_tensor,
in_tensor, in_tensor_maybe_partial,
[&](phi::DenseTensor* output, [&](phi::DenseTensor* output,
const phi::DenseTensor& input, const phi::DenseTensor& input,
BKCLContext_t comm, BKCLContext_t comm,
......
...@@ -349,6 +349,7 @@ from .framework import IPUPlace # noqa: F401 ...@@ -349,6 +349,7 @@ from .framework import IPUPlace # noqa: F401
from .framework import CUDAPlace # noqa: F401 from .framework import CUDAPlace # noqa: F401
from .framework import CUDAPinnedPlace # noqa: F401 from .framework import CUDAPinnedPlace # noqa: F401
from .framework import CustomPlace # noqa: F401 from .framework import CustomPlace # noqa: F401
from .framework import XPUPlace # noqa: F401
from .autograd import grad # noqa: F401 from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401 from .autograd import no_grad # noqa: F401
...@@ -380,7 +381,6 @@ from .device import is_compiled_with_cinn # noqa: F401 ...@@ -380,7 +381,6 @@ from .device import is_compiled_with_cinn # noqa: F401
from .device import is_compiled_with_cuda # noqa: F401 from .device import is_compiled_with_cuda # noqa: F401
from .device import is_compiled_with_rocm # noqa: F401 from .device import is_compiled_with_rocm # noqa: F401
from .device import is_compiled_with_custom_device # noqa: F401 from .device import is_compiled_with_custom_device # noqa: F401
from .device import XPUPlace # noqa: F401
# high-level api # high-level api
from .hapi import Model # noqa: F401 from .hapi import Model # noqa: F401
......
...@@ -25,6 +25,7 @@ from ..fluid.core import IPUPlace # noqa: F401 ...@@ -25,6 +25,7 @@ from ..fluid.core import IPUPlace # noqa: F401
from ..fluid.core import CUDAPlace # noqa: F401 from ..fluid.core import CUDAPlace # noqa: F401
from ..fluid.core import CUDAPinnedPlace # noqa: F401 from ..fluid.core import CUDAPinnedPlace # noqa: F401
from ..fluid.core import CustomPlace # noqa: F401 from ..fluid.core import CustomPlace # noqa: F401
from ..fluid.core import XPUPlace # noqa: F401
from ..fluid import core # noqa: F401 from ..fluid import core # noqa: F401
from ..fluid.dygraph import base, to_variable from ..fluid.dygraph import base, to_variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册