Unverified commit ae00f428, authored by Wen Sun, committed by GitHub

Support both use_calc_stream and sync_op in send recv APIs (#46023)

Parent 92e1f64b
......@@ -134,24 +134,56 @@ class ProcessGroup {
"ProcessGroup%s does not support send", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int) { // NOLINT
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
"ProcessGroup%s does not support recv", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int,
int,
int) { // NOLINT
virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>&, int, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
"ProcessGroup%s does not support recv with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor&, // NOLINT
int,
int,
int) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor&, int, int, int, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, int, int, int) { // NOLINT
phi::DenseTensor&, // NOLINT
int,
int,
int) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
"ProcessGroup%s does not support recv_partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor&, int, int, int, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
......
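The base ProcessGroup class now declares sync_op-aware overloads of Send, Recv, Send_Partial and Recv_Partial; they throw InvalidArgument unless a concrete backend overrides them. At the Python API level the flag decides whether the returned task still has to be waited on. A minimal sketch of the user-visible behavior, assuming a two-rank dygraph launch (adapted from the docstring examples added later in this commit):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
    task = dist.stream.send(data, dst=1, sync_op=False)
else:
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
    task = dist.stream.recv(data, src=0, sync_op=False)
task.wait()  # sync_op=False returns a pending task that must be waited on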
......@@ -51,6 +51,17 @@ std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
places, rank, comm_type, inputs);
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
const std::vector<Place>& places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool is_sync,
bool use_calc_stream) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(
places, rank, comm_type, inputs, is_sync, use_calc_stream);
}
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
......@@ -264,10 +275,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
auto& nccl_comms = places_to_ncclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
}
auto task = std::make_shared<ProcessGroupNCCL::NCCLTask>(
places, rank_, comm_type, inputs, sync_op, use_calc_stream);
auto task =
CreateTask(places, rank_, comm_type, inputs, sync_op, use_calc_stream);
platform::CUDADeviceGuard cuda_guard;
......@@ -406,6 +419,78 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
cuda_guard.SetDevice(places[0]);
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<phi::DenseTensor>& tensors,
Fn fn,
int dst_rank,
CommType op_type,
bool sync_op,
bool use_calc_stream) {
const auto& places = GetPlaceList(tensors);
const auto& key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
}
auto task =
CreateTask(places, rank_, op_type, tensors, sync_op, use_calc_stream);
platform::CUDADeviceGuard cuda_guard;
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
gpuStream_t nccl_stream;
if (use_calc_stream) {
nccl_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[i]))
->stream();
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
}
memory::RecordStream(tensors[i].Holder(), nccl_stream);
}
}
{
platform::NCCLGroupGuard nccl_guard;
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
gpuStream_t nccl_stream;
if (use_calc_stream) {
nccl_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[i]))
->stream();
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
}
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
}
}
if (!use_calc_stream) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
}
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<phi::DenseTensor>& tensors,
......@@ -617,6 +702,34 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
......@@ -640,6 +753,34 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
......@@ -647,10 +788,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);
std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
......@@ -671,16 +810,49 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors, int src_rank, int offset, int length) {
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);
std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
......@@ -701,6 +873,40 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
......
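In the NCCL backend, PointToPoint selects the stream from use_calc_stream: when the flag is set, ncclSend/ncclRecv are queued on the calculation stream of the device context and no control events are recorded; otherwise the dedicated communication stream is used and events are recorded for later synchronization. A hedged sketch of the calc-stream path from Python (this flag is only accepted together with sync_op=True):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
if dist.get_rank() == 0:
    # Queued on the calculation stream: later kernels on that stream are ordered
    # after the send by ordinary stream semantics, so no task.wait() is needed here.
    dist.stream.send(data, dst=1, sync_op=True, use_calc_stream=True)
else:
    dist.stream.recv(data, src=0, sync_op=True, use_calc_stream=True)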
......@@ -60,7 +60,7 @@ class ProcessGroupNCCL : public ProcessGroupStream {
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool is_sync,
bool sync_op,
bool use_calc_stream);
bool IsCompleted();
......@@ -122,19 +122,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank,
int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank,
int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
......@@ -180,9 +208,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places,
int rank,
CommType opType,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs);
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
const std::vector<Place>& places,
int rank,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs,
bool sync_op,
bool use_calc_stream);
protected:
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
......@@ -233,6 +269,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
int dst_rank,
CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<phi::DenseTensor>& tensors, // NOLINT
Fn fn,
int dst_rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
......
......@@ -45,5 +45,89 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
"ProcessGroup%s does not support do allreduce", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank, bool sync_op) {
return Send(tensors,
dst_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int offset,
int length,
bool sync_op) {
return Send_Partial(tensors,
dst_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank, bool sync_op) {
return Recv(tensors,
src_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int offset,
int length,
bool sync_op) {
return Recv_Partial(tensors,
src_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
} // namespace distributed
} // namespace paddle
......@@ -66,6 +66,58 @@ class ProcessGroupStream : public ProcessGroup {
const AllreduceOptions& options,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int dst_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int dst_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors, // NOLINT
int dst_rank,
int offset,
int length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors, // NOLINT
int dst_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int src_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int src_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, // NOLINT
int src_rank,
int offset,
int length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, // NOLINT
int src_rank,
int offset,
int length,
bool sync_op,
bool use_calc_stream);
};
} // namespace distributed
......
......@@ -196,6 +196,23 @@ void BindDistributed(py::module *m) {
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
......@@ -217,6 +234,30 @@ void BindDistributed(py::module *m) {
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int numel = (*dense).numel();
int send_numel = numel / nranks;
int offset = send_numel * rank_id;
return self.Send_Partial(
*dense, dst_rank, offset, send_numel, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
......@@ -232,6 +273,23 @@ void BindDistributed(py::module *m) {
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src, sync_op);
},
py::arg("tensor"),
py::arg("src"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
......@@ -253,6 +311,30 @@ void BindDistributed(py::module *m) {
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int numel = (*dense).numel();
int recv_numel = numel / nranks;
int offset = recv_numel * rank_id;
return self.Recv_Partial(
*dense, src_rank, offset, recv_numel, sync_op);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather",
[](distributed::ProcessGroup &self,
......@@ -427,6 +509,94 @@ void BindDistributed(py::module *m) {
},
py::arg("tensor"),
py::arg("op"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors,
dst,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int numel = (*dense).numel();
int send_numel = numel / nranks;
int offset = send_numel * rank_id;
return self.Send_Partial(*dense,
dst_rank,
offset,
send_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors,
src,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int numel = (*dense).numel();
int recv_numel = numel / nranks;
int offset = recv_numel * rank_id;
return self.Recv_Partial(*dense,
src_rank,
offset,
recv_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
......
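The pybind layer exposes both the sync_op variants and the *_on_calc_stream variants; for the partial bindings the slice offset is computed inside the lambda as numel / nranks * rank_id before dispatching to Send_Partial/Recv_Partial. A hedged sketch of driving the raw bindings directly (normally they are reached through the higher-level wrappers; the group lookup below uses an internal helper, and both ranks must agree on num and id):

import paddle
import paddle.distributed as dist
import paddle.distributed.collective as collective

dist.init_parallel_env()
group = collective._get_default_group()  # internal helper, also used by the wrappers in this commit
tensor = paddle.ones([8], dtype="float32")
if dist.get_rank() == 0:
    # Sends numel // num elements of the flattened tensor, starting at offset id * (numel // num).
    task = group.process_group.send_partial(tensor, 1, 2, 0, True)  # dst=1, num=2, id=0, sync_op=True
else:
    task = group.process_group.recv_partial(tensor, 0, 2, 0, True)  # src=0, num=2, id=0, sync_op=True
task.wait()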
......@@ -13,5 +13,7 @@
# limitations under the License.
from .all_reduce import all_reduce
from .send import send
from .recv import recv
__all__ = ["all_reduce"]
__all__ = ["all_reduce", "send", "recv"]
......@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.collective as collective
import paddle.fluid.framework as framework
from ...collective import _get_default_group, _get_reduce_op, ReduceOp
def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
op_type = _get_reduce_op(op, "all_reduce")
group = _get_default_group() if group is None else group
op_type = collective._get_reduce_op(op, "all_reduce")
group = collective._get_default_group() if group is None else group
if use_calc_stream:
return group.process_group.allreduce_on_calc_stream(tensor, op_type)
......@@ -30,7 +30,7 @@ def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
def all_reduce(tensor,
op=ReduceOp.SUM,
op=collective.ReduceOp.SUM,
group=None,
sync_op=True,
use_calc_stream=False):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.collective as collective
import paddle.fluid.framework as framework
def _recv_in_dygraph(tensor, src, group, sync_op, use_calc_stream):
group = collective._get_default_group() if group is None else group
if use_calc_stream:
return group.process_group.recv_on_calc_stream(tensor, src)
task = group.process_group.recv(tensor, src, sync_op)
if sync_op:
task.wait()
return task
def recv(tensor, src=0, group=None, sync_op=True, use_calc_stream=False):
"""
Receive a tensor from the source device.
Args:
tensor (Tensor): The tensor to receive. Support float16, float32, float64, int32, int64, int8, uint8 or bool as its data type.
src (int, optional): Rank of the source device. If none is given, use `0` as default.
group (Group, optional): Communicate in which group. If none is given, use the global group as default.
sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
option is designed for high-performance scenarios; do not turn it on unless you clearly understand its meaning.
Returns:
Return a task object.
Warning:
This API only supports the dygraph mode now.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
local_rank = dist.get_rank()
if local_rank == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
task = dist.stream.send(data, dst=1, sync_op=False)
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
task = dist.stream.recv(data, src=0, sync_op=False)
task.wait()
out = data.numpy()
# [[4, 5, 6], [4, 5, 6]]
"""
if group is not None and not group.is_member():
raise RuntimeError(
"The group should not be None and all ranks which invoke this operation should be the member of this group."
)
if not sync_op and use_calc_stream:
raise RuntimeError(
"use_calc_stream can only be True in sync op behavior.")
if framework.in_dygraph_mode():
return _recv_in_dygraph(tensor, src, group, sync_op, use_calc_stream)
raise RuntimeError(
"paddle.distributed.stream.recv is only supported in dygraph mode now.")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.collective as collective
import paddle.fluid.framework as framework
def _send_in_dygraph(tensor, dst, group, sync_op, use_calc_stream):
group = collective._get_default_group() if group is None else group
if use_calc_stream:
return group.process_group.send_on_calc_stream(tensor, dst)
task = group.process_group.send(tensor, dst, sync_op)
if sync_op:
task.wait()
return task
def send(tensor, dst=0, group=None, sync_op=True, use_calc_stream=False):
"""
Send a tensor to the destination device.
Args:
tensor (Tensor): The tensor to send. Support float16, float32, float64, int32, int64, int8, uint8 or bool as its data type.
dst (int, optional): Rank of the destination device. If none is given, use `0` as default.
group (Group, optional): Communicate in which group. If none is given, use the global group as default.
sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
option is designed for high-performance scenarios; do not turn it on unless you clearly understand its meaning.
Returns:
Return a task object.
Warning:
This API only supports the dygraph mode now.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
local_rank = dist.get_rank()
if local_rank == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
task = dist.stream.send(data, dst=1, sync_op=False)
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
task = dist.stream.recv(data, src=0, sync_op=False)
task.wait()
out = data.numpy()
# [[4, 5, 6], [4, 5, 6]]
"""
if group is not None and not group.is_member():
raise RuntimeError(
"The group should not be None and all ranks which invoke this operation should be the member of this group."
)
if not sync_op and use_calc_stream:
raise RuntimeError(
"use_calc_stream can only be True in sync op behavior.")
if framework.in_dygraph_mode():
return _send_in_dygraph(tensor, dst, group, sync_op, use_calc_stream)
raise RuntimeError(
"paddle.distributed.stream.send is only supported in dygraph mode now.")
......@@ -268,17 +268,26 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_eager_dist_api MODULES test_eager_dist_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT "120" LABELS
"RUN_TYPE=DIST")
test_communication_stream_allreduce_api MODULES
test_communication_stream_allreduce_api ENVS
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
set_tests_properties(test_communication_stream_allreduce_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_new_group_api MODULES test_new_group_api ENVS
test_communication_stream_sendrecv_api MODULES
test_communication_stream_sendrecv_api ENVS
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
set_tests_properties(test_communication_stream_sendrecv_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_eager_dist_api MODULES test_eager_dist_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_new_group_api PROPERTIES TIMEOUT "120" LABELS
"RUN_TYPE=DIST")
set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT "120" LABELS
"RUN_TYPE=DIST")
endif()
if((WITH_GPU
OR WITH_ROCM
......@@ -298,11 +307,10 @@ if((WITH_GPU
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_communication_stream_allreduce_api MODULES
test_communication_stream_allreduce_api ENVS
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
set_tests_properties(test_communication_stream_allreduce_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
test_new_group_api MODULES test_new_group_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_new_group_api PROPERTIES TIMEOUT "120" LABELS
"RUN_TYPE=DIST")
endif()
if((WITH_ROCM OR WITH_GPU) AND (LINUX))
bash_test_modules(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_collective_base
import test_communication_api_base as test_base
class StreamSendRecvTestCase():
def __init__(self):
self._sync_op = eval(os.getenv("sync_op"))
self._use_calc_stream = eval(os.getenv("use_calc_stream"))
self._backend = os.getenv("backend")
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
if self._backend not in ["nccl", "gloo"]:
raise NotImplementedError(
"Only support nccl and gloo as the backend for now.")
os.environ["PADDLE_DISTRI_BACKEND"] = self._backend
def run_test_case(self):
dist.init_parallel_env()
test_data_list = []
for seed in self._seeds:
test_data_list.append(
test_collective_base.create_test_data(shape=self._shape,
dtype=self._dtype,
seed=seed))
rank = dist.get_rank()
tensor = paddle.to_tensor(test_data_list[rank])
if rank == 0:
task = dist.stream.send(tensor,
dst=1,
sync_op=self._sync_op,
use_calc_stream=self._use_calc_stream)
else:
task = dist.stream.recv(tensor,
src=0,
sync_op=self._sync_op,
use_calc_stream=self._use_calc_stream)
if not self._sync_op:
task.wait()
result = test_data_list[0]
assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
StreamSendRecvTestCase().run_test_case()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import test_communication_api_base as test_base
class TestCommunicationStreamSendRecvAPI(test_base.CommunicationTestDistBase):
def setUp(self):
super(TestCommunicationStreamSendRecvAPI, self).setUp(num_of_devices=2,
timeout=120)
self._default_envs = {
"backend": "nccl",
"shape": "(100, 200)",
"dtype": "float32",
"seeds": str(self._seeds)
}
self._changeable_envs = {
"sync_op": ["True", "False"],
"use_calc_stream": ["True", "False"]
}
def test_sendrecv_stream(self):
envs_list = test_base.gen_product_envs_list(self._default_envs,
self._changeable_envs)
for envs in envs_list:
if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
continue
self.run_test_case("communication_stream_sendrecv_api_dygraph.py",
user_defined_envs=envs)
def tearDown(self):
super(TestCommunicationStreamSendRecvAPI, self).tearDown()
if __name__ == '__main__':
unittest.main()
......@@ -32,8 +32,9 @@ test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_
test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_wait,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_sendrecv_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,