未验证 提交 2d383b81 编写于 作者: L LiYuRio 提交者: GitHub

Remove place for process group (#47857)

上级 e0be4b94
...@@ -83,15 +83,14 @@ class ProcessGroup { ...@@ -83,15 +83,14 @@ class ProcessGroup {
}; };
public: public:
explicit ProcessGroup(int rank, int size, int gid);
virtual ~ProcessGroup() = default;
// TODO(dev): This constructor will be removed later.
explicit ProcessGroup(int rank, explicit ProcessGroup(int rank,
int size, int size,
const platform::Place& place, const platform::Place& place,
int gid); int gid);
explicit ProcessGroup(int rank, int size, int gid);
virtual ~ProcessGroup() {}
int GetRank() const { return rank_; } int GetRank() const { return rank_; }
int GetSize() const { return size_; } int GetSize() const { return size_; }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -68,11 +69,8 @@ void ProcessGroupBKCL::BKCLTask::Synchronize() { Wait(kWaitTimeout); } ...@@ -68,11 +69,8 @@ void ProcessGroupBKCL::BKCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store, ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
const platform::Place& place,
int gid) int gid)
: ProcessGroupStream(rank, size, place, gid), store_(store) { : ProcessGroupStream(rank, size, gid), store_(store) {}
platform::SetXPUDeviceId(place_.device);
}
void ProcessGroupBKCL::GroupStart() { void ProcessGroupBKCL::GroupStart() {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
...@@ -255,8 +253,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather( ...@@ -255,8 +253,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier( std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
const BarrierOptions& opts) { const BarrierOptions& opts) {
PADDLE_ENFORCE_GE(opts.device_id,
0,
platform::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0."));
platform::XPUPlace place(opts.device_id);
auto allocator = std::unique_ptr<phi::Allocator>( auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place_)); new paddle::experimental::DefaultAllocator(place));
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
phi::DenseTensor barrier_tensor{allocator.get(), meta}; phi::DenseTensor barrier_tensor{allocator.get(), meta};
......
...@@ -71,7 +71,6 @@ class ProcessGroupBKCL : public ProcessGroupStream { ...@@ -71,7 +71,6 @@ class ProcessGroupBKCL : public ProcessGroupStream {
ProcessGroupBKCL(const std::shared_ptr<Store>& store, ProcessGroupBKCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
const platform::Place& place,
int gid); int gid);
std::string GetBackendName() const override { std::string GetBackendName() const override {
......
...@@ -98,15 +98,11 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) { ...@@ -98,15 +98,11 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); } void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store, ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store,
const std::string& device_type,
int rank, int rank,
int size, int size,
const platform::Place& place,
int gid) int gid)
: ProcessGroup(rank, size, place, gid), : ProcessGroup(rank, size, gid), store_(store), device_type_(device_type) {}
store_(store),
device_type_(place.GetDeviceType()) {
phi::DeviceManager::SetDevice(place_);
}
void ProcessGroupCustom::BroadcastUniqueCustomID( void ProcessGroupCustom::BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT
...@@ -379,7 +375,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast( ...@@ -379,7 +375,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
const BarrierOptions& opts) { const BarrierOptions& opts) {
// Only support single card single process // Only support single card single process
std::vector<phi::CustomPlace> places = {place_}; PADDLE_ENFORCE_GE(opts.device_id,
0,
platform::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0."));
platform::CustomPlace place(device_type_, opts.device_id);
std::vector<phi::CustomPlace> places = {place};
std::vector<phi::DenseTensor> barrierTensors; std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size()); barrierTensors.reserve(places.size());
......
...@@ -64,9 +64,9 @@ class ProcessGroupCustom : public ProcessGroup { ...@@ -64,9 +64,9 @@ class ProcessGroupCustom : public ProcessGroup {
}; };
ProcessGroupCustom(const std::shared_ptr<Store>& store, ProcessGroupCustom(const std::shared_ptr<Store>& store,
const std::string& device_type,
int rank, int rank,
int size, int size,
const platform::Place& place,
int gid); int gid);
std::string GetBackendName() const override { return "XCCL_" + device_type_; } std::string GetBackendName() const override { return "XCCL_" + device_type_; }
......
...@@ -180,10 +180,9 @@ ProcessGroupGloo::ProcessGroupGloo( ...@@ -180,10 +180,9 @@ ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<distributed::Store>& store, const std::shared_ptr<distributed::Store>& store,
int rank, int rank,
int world_size, int world_size,
const platform::Place& place,
int gid, int gid,
const std::shared_ptr<GlooOptions> options) const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size, place, gid), : ProcessGroup(rank, world_size, gid),
_tag(0), _tag(0),
_store(new GlooStore(store)) { _store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size); _context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
......
...@@ -102,7 +102,6 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -102,7 +102,6 @@ class ProcessGroupGloo : public ProcessGroup {
const std::shared_ptr<paddle::distributed::Store>& store, const std::shared_ptr<paddle::distributed::Store>& store,
int rank, int rank,
int world_size, int world_size,
const platform::Place& place,
int gid, int gid,
std::shared_ptr<GlooOptions> options); std::shared_ptr<GlooOptions> options);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/collective/Common.h" #include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
DECLARE_bool(nccl_blocking_wait); DECLARE_bool(nccl_blocking_wait);
...@@ -81,11 +82,8 @@ void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); } ...@@ -81,11 +82,8 @@ void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store, ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
const platform::Place& place,
int gid) int gid)
: ProcessGroupStream(rank, size, place, gid), store_(store) { : ProcessGroupStream(rank, size, gid), store_(store) {}
platform::SetDeviceId(place_.device);
}
void ProcessGroupNCCL::GroupStart() { void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
...@@ -182,8 +180,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce( ...@@ -182,8 +180,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) { const BarrierOptions& opts) {
PADDLE_ENFORCE_GE(opts.device_id,
0,
platform::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0."));
platform::CUDAPlace place(opts.device_id);
auto allocator = std::unique_ptr<phi::Allocator>( auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place_)); new paddle::experimental::DefaultAllocator(place));
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
phi::DenseTensor barrier_tensor{allocator.get(), meta}; phi::DenseTensor barrier_tensor{allocator.get(), meta};
......
...@@ -85,7 +85,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -85,7 +85,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
ProcessGroupNCCL(const std::shared_ptr<Store>& store, ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
const platform::Place& place,
int gid); int gid);
std::string GetBackendName() const override { return "NCCL"; } std::string GetBackendName() const override { return "NCCL"; }
......
...@@ -17,11 +17,8 @@ ...@@ -17,11 +17,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
ProcessGroupStream::ProcessGroupStream(int rank, ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
int size, : ProcessGroup(rank, size, gid) {}
const platform::Place& place,
int gid)
: ProcessGroup(rank, size, place, gid) {}
const phi::DeviceContext& ProcessGroupStream::GetDeviceContext( const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
const Place& place, bool use_calc_stream) const { const Place& place, bool use_calc_stream) const {
......
...@@ -55,7 +55,7 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -55,7 +55,7 @@ class ProcessGroupStream : public ProcessGroup {
}; };
public: public:
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid); ProcessGroupStream(int rank, int size, int gid);
virtual ~ProcessGroupStream() = default; virtual ~ProcessGroupStream() = default;
virtual const phi::DeviceContext& GetDeviceContext( virtual const phi::DeviceContext& GetDeviceContext(
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <chrono> #include <chrono>
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "paddle/phi/common/place.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -33,7 +34,7 @@ struct BroadcastOptions { ...@@ -33,7 +34,7 @@ struct BroadcastOptions {
}; };
struct BarrierOptions { struct BarrierOptions {
std::vector<int> place_ids; int8_t device_id;
}; };
struct ReduceOptions { struct ReduceOptions {
......
...@@ -110,7 +110,7 @@ void BindDistributed(py::module *m) { ...@@ -110,7 +110,7 @@ void BindDistributed(py::module *m) {
py::class_<distributed::BarrierOptions>(*m, "BarrierOptions") py::class_<distributed::BarrierOptions>(*m, "BarrierOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("place_ids", &distributed::BarrierOptions::place_ids); .def_readwrite("device_id", &distributed::BarrierOptions::device_id);
py::class_<distributed::ReduceOptions>(*m, "ReduceOptions") py::class_<distributed::ReduceOptions>(*m, "ReduceOptions")
.def(py::init<>()) .def(py::init<>())
...@@ -513,12 +513,12 @@ void BindDistributed(py::module *m) { ...@@ -513,12 +513,12 @@ void BindDistributed(py::module *m) {
.def( .def(
"barrier", "barrier",
[](distributed::ProcessGroup &self, std::vector<int> place_ids) { [](distributed::ProcessGroup &self, int8_t device_id) {
distributed::BarrierOptions opts; distributed::BarrierOptions opts;
opts.place_ids = place_ids; opts.device_id = device_id;
return self.Barrier(opts); return self.Barrier(opts);
}, },
py::arg("place_ids") = std::vector<int>{}, py::arg("device_id") = -1,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
// TODO(liyurui): Interface below will be removed in the future. // TODO(liyurui): Interface below will be removed in the future.
...@@ -1214,12 +1214,10 @@ void BindDistributed(py::module *m) { ...@@ -1214,12 +1214,10 @@ void BindDistributed(py::module *m) {
.def(py::init<const std::shared_ptr<distributed::Store> &, .def(py::init<const std::shared_ptr<distributed::Store> &,
int, int,
int, int,
const platform::CUDAPlace &,
int>(), int>(),
py::arg("store"), py::arg("store"),
py::arg("rank"), py::arg("rank"),
py::arg("world_size"), py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0, py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
...@@ -1254,14 +1252,14 @@ void BindDistributed(py::module *m) { ...@@ -1254,14 +1252,14 @@ void BindDistributed(py::module *m) {
std::shared_ptr<distributed::ProcessGroupCustom>>( std::shared_ptr<distributed::ProcessGroupCustom>>(
*m, "ProcessGroupCustom", ProcessGroup) *m, "ProcessGroupCustom", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, .def(py::init<const std::shared_ptr<distributed::Store> &,
const std::string &,
int, int,
int, int,
const platform::CustomPlace &,
int>(), int>(),
py::arg("store"), py::arg("store"),
py::arg("device_type"),
py::arg("rank"), py::arg("rank"),
py::arg("world_size"), py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0, py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
...@@ -1275,12 +1273,10 @@ void BindDistributed(py::module *m) { ...@@ -1275,12 +1273,10 @@ void BindDistributed(py::module *m) {
.def(py::init<const std::shared_ptr<distributed::Store> &, .def(py::init<const std::shared_ptr<distributed::Store> &,
int, int,
int, int,
const platform::XPUPlace &,
int>(), int>(),
py::arg("store"), py::arg("store"),
py::arg("rank"), py::arg("rank"),
py::arg("world_size"), py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0, py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
#endif #endif
...@@ -1303,14 +1299,12 @@ void BindDistributed(py::module *m) { ...@@ -1303,14 +1299,12 @@ void BindDistributed(py::module *m) {
.def(py::init<const std::shared_ptr<paddle::distributed::Store> &, .def(py::init<const std::shared_ptr<paddle::distributed::Store> &,
int, int,
int, int,
const platform::CPUPlace &,
int, int,
std::shared_ptr<GlooOptions> &>(), std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store, .def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
int rank, int rank,
int world_size, int world_size,
const platform::CPUPlace &place,
int gid) { int gid) {
auto opts = GlooOptions::create(); auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
...@@ -1321,12 +1315,11 @@ void BindDistributed(py::module *m) { ...@@ -1321,12 +1315,11 @@ void BindDistributed(py::module *m) {
opts->device = ProcessGroupGloo::createDefaultDevice(); opts->device = ProcessGroupGloo::createDefaultDevice();
} }
return std::make_shared<ProcessGroupGloo>( return std::make_shared<ProcessGroupGloo>(
store, rank, world_size, place, gid, opts); store, rank, world_size, gid, opts);
}), }),
py::arg("store"), py::arg("store"),
py::arg("rank"), py::arg("rank"),
py::arg("world_size"), py::arg("world_size"),
py::arg("place"),
py::arg("group_id") = 0, py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device", .def_static("create_default_device",
......
...@@ -152,17 +152,15 @@ def _new_process_group_impl( ...@@ -152,17 +152,15 @@ def _new_process_group_impl(
genv = _get_global_env() genv = _get_global_env()
assert backend in _valid_backend_list, "Unsupported backend: %s." % backend assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
if backend == "gloo": if backend == "gloo":
place = core.CPUPlace() pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
elif backend == "nccl": elif backend == "nccl":
place = core.CUDAPlace(genv.device_id) pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
elif backend == "xccl": elif backend == "xccl":
place = core.CustomPlace(genv.device_type, genv.device_id) pg = core.ProcessGroupCustom(
pg = core.ProcessGroupCustom(store, rank, world_size, place, group_id) store, genv.device_type, rank, world_size, group_id
)
elif backend == "bkcl": elif backend == "bkcl":
place = core.XPUPlace(genv.device_id) pg = core.ProcessGroupBKCL(store, rank, world_size, group_id)
pg = core.ProcessGroupBKCL(store, rank, world_size, place, group_id)
return pg return pg
...@@ -192,7 +190,12 @@ def barrier(group=None): ...@@ -192,7 +190,12 @@ def barrier(group=None):
if in_dygraph_mode(): if in_dygraph_mode():
group = _get_default_group() if group is None else group group = _get_default_group() if group is None else group
task = group.process_group.barrier() place = paddle.fluid.framework._current_expected_place()
if isinstance(place, paddle.fluid.core.CPUPlace):
task = group.process_group.barrier()
else:
device_id = place.get_device_id()
task = group.process_group.barrier(device_id)
task.wait() task.wait()
return return
......
...@@ -30,9 +30,9 @@ def init_process_group(strategy=None): ...@@ -30,9 +30,9 @@ def init_process_group(strategy=None):
store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks) store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
pg_group = core.ProcessGroupCustom( pg_group = core.ProcessGroupCustom(
store, store,
ParallelEnv().device_type,
rank, rank,
nranks, nranks,
paddle.CustomPlace(ParallelEnv().device_type, ParallelEnv().device_id),
) )
return pg_group return pg_group
...@@ -51,9 +51,8 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -51,9 +51,8 @@ class TestProcessGroupFp32(unittest.TestCase):
def test_create_process_group_xccl(self): def test_create_process_group_xccl(self):
with _test_eager_guard(): with _test_eager_guard():
paddle.set_device( device_id = paddle.distributed.ParallelEnv().dev_id
'custom_cpu:%d' % paddle.distributed.ParallelEnv().dev_id paddle.set_device('custom_cpu:%d' % device_id)
)
pg = init_process_group() pg = init_process_group()
...@@ -119,11 +118,11 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -119,11 +118,11 @@ class TestProcessGroupFp32(unittest.TestCase):
# test barrier # test barrier
# rank 0 # rank 0
if pg.rank() == 0: if pg.rank() == 0:
task = pg.barrier() task = pg.barrier(device_id)
task.wait() task.wait()
# rank 1 # rank 1
else: else:
task = pg.barrier() task = pg.barrier(device_id)
task.wait() task.wait()
print("test barrier api ok\n") print("test barrier api ok\n")
......
...@@ -42,8 +42,7 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -42,8 +42,7 @@ class TestProcessGroupFp32(unittest.TestCase):
store = paddle.fluid.core.TCPStore( store = paddle.fluid.core.TCPStore(
"127.0.0.1", 6272, is_master, nranks, 30 "127.0.0.1", 6272, is_master, nranks, 30
) )
place = paddle.fluid.core.CPUPlace() pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks, place)
# test allreduce sum # test allreduce sum
# rank 0 # rank 0
......
...@@ -44,9 +44,8 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -44,9 +44,8 @@ class TestProcessGroupFp32(unittest.TestCase):
def test_create_process_group_nccl(self): def test_create_process_group_nccl(self):
with _test_eager_guard(): with _test_eager_guard():
paddle.set_device( device_id = paddle.distributed.ParallelEnv().dev_id
'gpu:%d' % paddle.distributed.ParallelEnv().dev_id paddle.set_device('gpu:%d' % device_id)
)
pg = init_process_group() pg = init_process_group()
print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name()) print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
...@@ -170,10 +169,10 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -170,10 +169,10 @@ class TestProcessGroupFp32(unittest.TestCase):
# test barrier # test barrier
# rank 0 # rank 0
if pg.rank() == 0: if pg.rank() == 0:
dist.barrier() pg.barrier(device_id)
# rank 1 # rank 1
else: else:
task = pg.barrier() task = pg.barrier(device_id)
task.wait() task.wait()
print("test barrier api ok\n") print("test barrier api ok\n")
......
...@@ -20,7 +20,6 @@ import sys ...@@ -20,7 +20,6 @@ import sys
import paddle import paddle
from paddle.fluid.framework import _test_eager_guard from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
import paddle.distributed as dist
def init_process_group(strategy=None): def init_process_group(strategy=None):
...@@ -45,9 +44,8 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -45,9 +44,8 @@ class TestProcessGroupFp32(unittest.TestCase):
def test_create_process_group_bkcl(self): def test_create_process_group_bkcl(self):
with _test_eager_guard(): with _test_eager_guard():
paddle.set_device( device_id = paddle.distributed.ParallelEnv().dev_id
'xpu:%d' % paddle.distributed.ParallelEnv().dev_id paddle.set_device('xpu:%d' % device_id)
)
pg = init_process_group() pg = init_process_group()
sys.stdout.write( sys.stdout.write(
...@@ -108,10 +106,10 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -108,10 +106,10 @@ class TestProcessGroupFp32(unittest.TestCase):
# test barrier # test barrier
# rank 0 # rank 0
if pg.rank() == 0: if pg.rank() == 0:
dist.barrier() pg.barrier(device_id)
# rank 1 # rank 1
else: else:
task = pg.barrier() task = pg.barrier(device_id)
task.wait() task.wait()
sys.stdout.write("rank {}: test barrier api ok\n".format(pg.rank())) sys.stdout.write("rank {}: test barrier api ok\n".format(pg.rank()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册