Unverified · Commit 93c58390 authored by ShenLiang, committed by GitHub

[Distributed] Opt nccl connection by lazy initialization (#55005)

Parent 51c414b6
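The commit defers NCCL communicator creation until the first collective or point-to-point operation that actually needs it, instead of building every connection eagerly when the process group is created. The standalone C++ sketch below is not part of the diff; Comm, CommCache, and GetOrCreate are illustrative names, not the PR's API. It only shows the create-on-first-use caching pattern the change revolves around.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for an NCCL communicator; the real code would wrap ncclComm_t.
struct Comm {
  explicit Comm(const std::string& key) : key(key) {}
  std::string key;
};

// Lazily creates one communicator per "place key" (e.g. a device id, or a
// peer-specific key for p2p), only when a call first touches that key.
class CommCache {
 public:
  Comm* GetOrCreate(const std::string& place_key) {
    auto it = cache_.find(place_key);
    if (it == cache_.end()) {
      std::cout << "creating comm for " << place_key << std::endl;
      it = cache_.emplace(place_key, std::make_unique<Comm>(place_key)).first;
    }
    return it->second.get();
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Comm>> cache_;
};

int main() {
  CommCache cache;
  cache.GetOrCreate("gpu:0");     // first collective on device 0 builds the comm here
  cache.GetOrCreate("gpu:0");     // later calls reuse the cached comm
  cache.GetOrCreate("gpu:0->1");  // a standalone p2p call gets its own lazily-built comm
  return 0;
}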
......@@ -21,6 +21,7 @@
#include <vector>
#include "paddle/fluid/distributed/collective/types.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h" // NOTE: this header is required somewhere
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
......@@ -34,22 +35,6 @@ namespace distributed {
constexpr int kIgnoreId = -1;
enum class CommType : std::uint8_t {
BROADCAST = 0,
ALLREDUCE = 1,
ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce
REDUCE = 3,
ALLGATHER = 4,
GATHER = 5,
SCATTER = 6,
REDUCE_SCATTER = 7,
ALLTOALL = 8,
SEND = 9,
RECV = 10,
BARRIER = 11,
UNKNOWN = 100,
};
class ProcessGroup {
public:
class Task {
......@@ -405,68 +390,57 @@ class ProcessGroup {
// legacy APIs
// TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const AllreduceOptions& = AllreduceOptions()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
const AllreduceOptions& options = AllreduceOptions()) {
return AllReduce(outputs.data(), inputs.front(), options, false);
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const AllreduceOptions&,
bool) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce with sync_op flag",
GetBackendName()));
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
const AllreduceOptions& options,
bool sync_op) {
return AllReduce(outputs.data(), inputs.front(), options, sync_op);
}
// TODO(sunyilun): methods below will be removed later
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast", GetBackendName()));
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
const BroadcastOptions& options = BroadcastOptions()) {
return Broadcast(outputs.data(), inputs.front(), options, false);
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions&,
bool) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
const BroadcastOptions& options,
bool sync_op) {
return Broadcast(outputs.data(), inputs.front(), options, sync_op);
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
std::vector<phi::DenseTensor>& tensors, int dst_rank) { // NOLINT
return Send(tensors.front(), dst_rank, false);
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support recv", GetBackendName()));
std::vector<phi::DenseTensor>& tensors, int src_rank) { // NOLINT
return Recv(&tensors.front(), src_rank, false);
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather", GetBackendName()));
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors) { // NOLINT
return AllGather(out_tensors.data(), in_tensors.front(), false);
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
bool) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
bool sync_op) {
return AllGather(out_tensors.data(), in_tensors.front(), sync_op);
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
......@@ -477,19 +451,17 @@ class ProcessGroup {
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>& ins, // NOLINT
std::vector<phi::DenseTensor>& outs, // NOLINT
const ReduceOptions& opts) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support reduce", GetBackendName()));
return Reduce(outs.data(), ins.front(), opts, false);
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&) {
PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support scatter", GetBackendName()));
std::vector<phi::DenseTensor>& ins, // NOLINT
std::vector<phi::DenseTensor>& outs, // NOLINT
const ScatterOptions& opts) {
return Scatter(outs.data(), ins.front(), opts, false);
}
protected:
......
......@@ -16,7 +16,6 @@
#include "paddle/fluid/distributed/collective/bkcl_tools.h"
#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/phi/api/lib/utils/allocator.h"
......
......@@ -16,7 +16,6 @@
#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
......
......@@ -169,42 +169,6 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
ncclComm_t NCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
......@@ -212,42 +176,32 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream);
void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id,
bool is_p2p_op = false,
const std::string& p2p_key = "",
int p2p_rank = 0);
void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
void CreateNCCLEnvCache(const Place& place,
const std::string& place_key,
CommType comm_type,
int p2p_rank = 0);
void SyncCalcStream(const Place& place);
void SyncCalcStream(const Place& place, const std::string& place_key);
std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
std::shared_ptr<ProcessGroup::Task> Collective(
std::function<void(ncclComm_t, gpuStream_t)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places,
int rank,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn,
CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<phi::DenseTensor>& tensors, // NOLINT
Fn fn,
int dst_rank,
CommType op_type);
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
std::shared_ptr<ProcessGroup::Task> Point2Point(
std::function<void(ncclComm_t, gpuStream_t, int)> fn,
int peer,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
private:
std::shared_ptr<phi::distributed::Store> store_;
......@@ -260,7 +214,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
static uint64_t s_group_call_counter;
};
} // namespace distributed
......
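In process_group_nccl.h above, the templated Collective/PointToPoint helpers are replaced by Collective/Point2Point entry points that take a std::function receiving the communicator and stream. The standalone sketch below shows the shape of that dispatch; RunCollective, CommStub, and StreamStub are illustrative stand-ins, not the PR's types.

#include <functional>
#include <iostream>

// Stand-ins for ncclComm_t and gpuStream_t.
struct CommStub {};
struct StreamStub {};

// std::function-based dispatch: the caller packages the actual NCCL call in a
// lambda, and the runner supplies the (lazily created) communicator and stream.
void RunCollective(const std::function<void(CommStub*, StreamStub*)>& fn,
                   bool sync_op) {
  static CommStub comm;      // in the real code this would come from the lazy cache
  static StreamStub stream;  // and the comm/calc stream choice depends on flags
  fn(&comm, &stream);
  if (sync_op) {
    std::cout << "wait on the stream before returning" << std::endl;
  }
}

int main() {
  RunCollective(
      [](CommStub* comm, StreamStub* stream) {
        std::cout << "issue the NCCL call on this comm/stream" << std::endl;
      },
      /*sync_op=*/true);
  return 0;
}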
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
......@@ -28,5 +27,29 @@ inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor,
return tensor_flattened.Slice(offset, offset + numel);
}
enum class CommType : std::uint8_t {
BROADCAST = 0,
ALLREDUCE = 1,
ALLREDUCE_SPARSE = 2, // TODO(shenliang03): to support sparse in allreduce
REDUCE = 3,
ALLGATHER = 4,
GATHER = 5,
SCATTER = 6,
REDUCE_SCATTER = 7,
ALLTOALL = 8,
SEND = 9,
RECV = 10,
BARRIER = 11,
UNKNOWN = 100,
};
inline bool IsP2POP(CommType comm_type, bool is_batch_p2p = false) {
if (is_batch_p2p) {
return false;
} else {
return comm_type == CommType::SEND || comm_type == CommType::RECV;
}
}
} // namespace distributed
} // namespace paddle
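The IsP2POP helper added above classifies standalone SEND/RECV as point-to-point while treating batched p2p calls as collectives. The standalone sketch below shows how such a predicate can steer communicator setup; DispatchEnvSetup is a hypothetical caller, not part of the diff.

#include <cstdint>
#include <iostream>

enum class CommType : std::uint8_t { BROADCAST = 0, ALLREDUCE = 1, SEND = 9, RECV = 10 };

// Same logic as the helper added above: batched p2p ops are treated as
// collectives; only standalone SEND/RECV count as point-to-point.
inline bool IsP2POP(CommType comm_type, bool is_batch_p2p = false) {
  if (is_batch_p2p) {
    return false;
  }
  return comm_type == CommType::SEND || comm_type == CommType::RECV;
}

void DispatchEnvSetup(CommType comm_type, bool is_batch_p2p) {
  if (IsP2POP(comm_type, is_batch_p2p)) {
    std::cout << "p2p path: build a peer-specific communicator" << std::endl;
  } else {
    std::cout << "collective path: build or reuse the group communicator" << std::endl;
  }
}

int main() {
  DispatchEnvSetup(CommType::SEND, /*is_batch_p2p=*/false);       // p2p path
  DispatchEnvSetup(CommType::SEND, /*is_batch_p2p=*/true);        // collective path
  DispatchEnvSetup(CommType::ALLREDUCE, /*is_batch_p2p=*/false);  // collective path
  return 0;
}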
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/global_gather_op.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
......@@ -221,7 +222,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
out->mutable_data<T>(out_dims, place);
for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
distributed::ProcessGroupNCCL::GroupStart();
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
......@@ -241,7 +242,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
/*sync_op*/ true);
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
distributed::ProcessGroupNCCL::GroupEnd();
}
#ifdef PADDLE_WITH_CUDA
......
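global_gather above (and global_scatter below) now call ProcessGroupNCCL::GroupStart()/GroupEnd() instead of ncclGroupStart()/ncclGroupEnd() directly, and process_group_nccl.h gains a static s_group_call_counter. The standalone sketch below is one plausible counting wrapper under that assumption; GroupGuard and the stub functions are illustrative, and the real implementation may differ.

#include <cstdint>
#include <iostream>

// Stand-ins for ncclGroupStart/ncclGroupEnd.
void NcclGroupStartStub() { std::cout << "ncclGroupStart" << std::endl; }
void NcclGroupEndStub() { std::cout << "ncclGroupEnd" << std::endl; }

// Hypothetical counting wrapper: forwards to the NCCL group calls and keeps a
// counter so the process group can tell whether it is inside a group call
// (useful, e.g., when deciding how lazily created communicators are handled).
class GroupGuard {
 public:
  static void GroupStart() {
    NcclGroupStartStub();
    ++group_call_counter_;
  }
  static void GroupEnd() {
    NcclGroupEndStub();
    --group_call_counter_;
  }
  static bool InsideGroup() { return group_call_counter_ > 0; }

 private:
  static uint64_t group_call_counter_;
};

uint64_t GroupGuard::group_call_counter_ = 0;

int main() {
  GroupGuard::GroupStart();
  std::cout << std::boolalpha << GroupGuard::InsideGroup() << std::endl;  // true
  GroupGuard::GroupEnd();
  std::cout << std::boolalpha << GroupGuard::InsideGroup() << std::endl;  // false
  return 0;
}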
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/global_scatter_op.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
......@@ -219,7 +220,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
out->mutable_data<T>(out_dims, place);
for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
distributed::ProcessGroupNCCL::GroupStart();
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
......@@ -239,7 +240,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
recv_ptr += cpu_global_count_data[idx];
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
distributed::ProcessGroupNCCL::GroupEnd();
}
#ifdef PADDLE_WITH_CUDA
......
......@@ -267,8 +267,8 @@ void BindDistributed(py::module *m) {
in_tensor.impl());
auto in_dense = *p_in_tensor;
auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(out_dense, in_dense, sync_op);
auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(*dev_ctx);
return task;
......@@ -322,8 +322,6 @@ void BindDistributed(py::module *m) {
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
auto *dev_ctx =
self.GetDeviceContext(in_tensor_list.back().place());
int world_size = self.GetSize();
auto task =
self.AllToAll(out_dense,
......@@ -331,6 +329,8 @@ void BindDistributed(py::module *m) {
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
sync_op);
auto *dev_ctx =
self.GetDeviceContext(in_tensor_list.back().place());
SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(*dev_ctx);
return task;
......@@ -544,11 +544,11 @@ void BindDistributed(py::module *m) {
in_tensor.impl());
auto in_dense = *p_in_tensor;
auto *dev_ctx =
self.GetDeviceContext(in_tensor.place(), use_calc_stream);
distributed::GatherOptions gather_opts{dst};
auto task = self.Gather(
out_dense, in_dense, gather_opts, sync_op, use_calc_stream);
auto *dev_ctx =
self.GetDeviceContext(in_tensor.place(), use_calc_stream);
SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
if (!use_calc_stream &&
dev_ctx->GetPlace() != platform::CPUPlace()) {
......@@ -584,8 +584,7 @@ void BindDistributed(py::module *m) {
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
return self.AllReduce(dense.get(), *dense, opts, false);
},
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
......@@ -601,8 +600,7 @@ void BindDistributed(py::module *m) {
opts.source_rank = source_rank;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
return self.Broadcast(dense.get(), *dense, opts, false);
},
py::arg("tensor"),
py::arg("source_rank"),
......@@ -616,8 +614,7 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
return self.Send(*dense, dst, false);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -631,8 +628,7 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
return self.Recv(dense.get(), src, false);
},
py::arg("tensor"),
py::arg("src"),
......@@ -649,9 +645,7 @@ void BindDistributed(py::module *m) {
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
return self.AllGather(out_dense.get(), *in_dense, false);
},
py::arg("in"),
py::arg("out"),
......@@ -697,9 +691,14 @@ void BindDistributed(py::module *m) {
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
int world_size = self.GetSize();
return self.AllToAll(
out_dense.get(),
*in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(*in_dense, world_size),
false);
},
py::arg("in"),
py::arg("out"),
......@@ -743,8 +742,7 @@ void BindDistributed(py::module *m) {
opts.root_rank = dst;
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
return self.Reduce(dense.get(), *dense, opts, false);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -765,9 +763,7 @@ void BindDistributed(py::module *m) {
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
return self.Scatter(out_dense.get(), *in_dense, opts, false);
},
py::arg("in"),
py::arg("out"),
......@@ -790,12 +786,11 @@ void BindDistributed(py::module *m) {
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(out_dense,
in_dense,
/*sync_op*/ true,
/*use_calc_stream*/ true);
auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true);
SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
return task;
},
......@@ -902,8 +897,6 @@ void BindDistributed(py::module *m) {
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
auto *dev_ctx = self.GetDeviceContext(
in_tensor_list.back().place(), /*use_calc_stream*/ true);
int world_size = self.GetSize();
auto task =
self.AllToAll(out_dense,
......@@ -912,6 +905,8 @@ void BindDistributed(py::module *m) {
GetDefaultSplitSizes(in_dense, world_size),
/*sync_op*/ true,
/*use_calc_stream*/ true);
auto *dev_ctx = self.GetDeviceContext(
in_tensor_list.back().place(), /*use_calc_stream*/ true);
SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
return task;
},
......
......@@ -146,8 +146,25 @@ void InitTensorWithNumpyValue(TensorObject* self,
if (platform::is_cpu_place(place)) {
SetTensorFromPyArray<platform::CPUPlace>(impl_ptr, array, place, zero_copy);
} else if (platform::is_xpu_place(place)) {
#if defined(PADDLE_WITH_XPU)
phi::backends::xpu::SetXPUDeviceId(place.device);
VLOG(4) << "CurrentDeviceId: "
<< phi::backends::xpu::GetXPUCurrentDeviceId() << " from "
<< static_cast<int>(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
SetTensorFromPyArray<platform::XPUPlace>(impl_ptr, array, place, zero_copy);
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device);
VLOG(4) << "CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId()
<< " from " << static_cast<int>(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
SetTensorFromPyArray<platform::CUDAPlace>(
impl_ptr, array, place, zero_copy);
} else if (platform::is_cuda_pinned_place(place)) {
......@@ -156,6 +173,15 @@ void InitTensorWithNumpyValue(TensorObject* self,
} else if (platform::is_npu_place(place)) {
SetTensorFromPyArray<platform::NPUPlace>(impl_ptr, array, place, zero_copy);
} else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DeviceManager::SetDevice(place);
VLOG(4) << "CurrentDeviceId: "
<< phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from "
<< static_cast<int>(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace."));
#endif
SetTensorFromPyArray<platform::CustomPlace>(
impl_ptr, array, place, zero_copy);
} else {
......
......@@ -236,12 +236,6 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
# TODO: The method below is a new method for group management, will replace the previous
# three in the future.
_add_new_group(group)
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by tcp
paddle.distributed.barrier(group=group)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.barrier()
return group
if not backend:
......
......@@ -164,15 +164,25 @@ class HybridCommunicateGroup:
)
)
# create comm group for pipe parallel
self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")
# NOTE(shenliang03): In pipeline parallel we use batch_isend_irecv.
# If batch_isend_irecv is the first collective operation, all ranks of
# the pipeline group must participate in that call. To avoid this
# situation, we perform a collective communication in advance, which
# creates the communicator eagerly.
paddle.distributed.all_reduce(
paddle.zeros([1], dtype="int32"),
op=paddle.distributed.ReduceOp.SUM,
group=self._pp_comm_group,
)
# create comm group for data parallel
self._dp_group, self._dp_comm_group = self._set_comm_group("data")
# create comm group for model parallel
self._mp_group, self._mp_comm_group = self._set_comm_group("model")
# create comm group for pipe parallel
self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")
# create comm group for sharding parallel
self._sharding_group, self._sharding_comm_group = self._set_comm_group(
"sharding"
......
......@@ -1115,8 +1115,6 @@ def init_parallel_env():
_set_group_map_backend(group, backend)
_add_new_group(group)
parallel_helper._set_parallel_ctx(True)
paddle.distributed.barrier(group=group)
return group
node_num = {i.split(":")[0] for i in parallel_env.trainer_endpoints}
......