Commit 26a83ed1 authored by S sneaxiy

hack event

Parent 5a9214d8
...@@ -27,6 +27,7 @@
DECLARE_bool(benchmark);
DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
DECLARE_bool(enable_process_group_event_record);
// set this flag to `true` and recompile to enable dynamic checks
constexpr bool FLAGS_enable_nccl_dynamic_check = false;
...@@ -94,7 +95,96 @@ ProcessGroupNCCL::ProcessGroupNCCL(
int rank,
int size,
int gid)
: ProcessGroupWithStream(rank, size, gid), store_(store) {
events_.resize(phi::backends::gpu::GetGPUDeviceCount());
}
ProcessGroupNCCL::~ProcessGroupNCCL() {
for (const auto& e : events_) {
for (const auto& p : e.events) {
cudaEventDestroy(p.first);
cudaEventDestroy(p.second);
}
}
}
std::vector<double> ProcessGroupNCCL::GetEventTimeAndRelease() {
auto dev_id = phi::backends::gpu::GetCurrentDeviceId();
auto& e = events_[dev_id];
std::vector<double> times;
times.reserve(e.length);
for (size_t i = 0; i < e.length; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(e.events[i].second));
float ms = 0.0f;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventElapsedTime(&ms, e.events[i].first, e.events[i].second));
times.push_back(ms);
}
e.length = 0;
return times;
}
gpuEvent_t ProcessGroupNCCL::RecordStartEventOnCalcStream() {
if (!FLAGS_enable_process_group_event_record) {
return nullptr;
}
auto dev_id = phi::backends::gpu::GetCurrentDeviceId();
auto stream = static_cast<phi::GPUContext*>(
GetDeviceContext(phi::GPUPlace(dev_id), true))
->stream();
return RecordStartEvent(stream);
}
void ProcessGroupNCCL::RecordEndEventOnCalcStream(gpuEvent_t event) {
if (event == nullptr) {
return;
}
auto dev_id = phi::backends::gpu::GetCurrentDeviceId();
auto stream = static_cast<phi::GPUContext*>(
GetDeviceContext(phi::GPUPlace(dev_id), true))
->stream();
RecordEndEvent(event, stream);
}
gpuEvent_t ProcessGroupNCCL::RecordStartEvent(gpuStream_t stream) {
if (!FLAGS_enable_process_group_event_record) {
return nullptr;
}
if (s_group_call_counter > 0) {
return nullptr;
}
auto dev_id = phi::backends::gpu::GetCurrentDeviceId();
gpuEvent_t start_event, end_event;
auto& e = events_[dev_id];
if (e.events.size() <= e.length) {
VLOG(10) << "Create new events when cached event pair number is "
<< e.events.size() << " , and used event pair number is "
<< e.length;
e.events.resize(e.events.size() + 1);
auto& p = e.events[e.length++];
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&p.first));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&p.second));
start_event = p.first;
end_event = p.second;
} else {
start_event = e.events[e.length].first;
end_event = e.events[e.length].second;
++e.length;
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream));
return end_event;
}
void ProcessGroupNCCL::RecordEndEvent(gpuEvent_t event, gpuStream_t stream) {
if (event != nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, stream));
}
}
void ProcessGroupNCCL::GroupStart() {
NCCL_CHECK(phi::dynload::ncclGroupStart());
...@@ -228,6 +318,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
<< ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
NCCL_CHECK(
phi::dynload::ncclAllReduce(in_tensor.data(),
out_tensor->data(),
...@@ -236,6 +327,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
ToNCCLRedType(opts.reduce_op),
comm,
stream));
RecordEndEvent(event, stream);
},
in_tensor,
CommType::ALLREDUCE,
...@@ -310,6 +402,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
<< ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
GroupStart();
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
...@@ -335,6 +428,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
out_offset += out_numel;
}
GroupEnd();
RecordEndEvent(event, stream);
},
in_tensor,
CommType::ALLTOALL,
...@@ -396,6 +490,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
<< ", nranks: " << size_ << ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
NCCL_CHECK(
phi::dynload::ncclBroadcast(in_tensor.data(),
out_tensor->data(),
...@@ -404,6 +499,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
root,
comm,
stream));
RecordEndEvent(event, stream);
},
in_tensor,
CommType::BROADCAST,
...@@ -444,6 +540,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
<< ", nranks: " << size_ << ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
NCCL_CHECK(
phi::dynload::ncclReduce(in_tensor.data(),
out_tensor->data(),
...@@ -453,6 +550,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
opts.root_rank,
comm,
stream));
RecordEndEvent(event, stream);
},
in_tensor,
CommType::REDUCE,
...@@ -492,6 +590,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
<< ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
NCCL_CHECK(phi::dynload::ncclReduceScatter(
in_tensor.data(),
out_tensor->data(),
...@@ -500,6 +599,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
ToNCCLRedType(opts.reduce_op),
comm,
stream));
RecordEndEvent(event, stream);
},
in_tensor,
CommType::REDUCE_SCATTER,
...@@ -543,6 +643,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
auto event = RecordStartEvent(stream);
GroupStart();
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
...@@ -563,7 +664,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
comm,
stream));
GroupEnd();
RecordEndEvent(event, stream);
} else {
auto event = RecordStartEvent(stream);
NCCL_CHECK(
phi::dynload::ncclRecv(out_tensor->data(),
numel,
...@@ -571,6 +674,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
opts.root_rank,
comm,
stream));
RecordEndEvent(event, stream);
}
},
in_tensor,
...@@ -627,6 +731,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
<< ", nranks: " << size_ << ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
GroupStart();
// root receive from all devices
if (rank_ == opts.root_rank) {
...@@ -649,6 +754,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
comm,
stream));
GroupEnd();
RecordEndEvent(event, stream);
};
return Collective(
gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream);
...@@ -688,12 +794,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
<< ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
NCCL_CHECK(phi::dynload::ncclRecv(tensor->data(),
tensor->numel(),
phi::ToNCCLDataType(tensor->dtype()),
rank_in_group,
comm,
stream));
RecordEndEvent(event, stream);
},
src_rank,
*tensor,
...@@ -735,6 +843,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
<< ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
auto event = RecordStartEvent(stream);
NCCL_CHECK(phi::dynload::ncclSend(
tensor_maybe_partial.data(),
tensor_maybe_partial.numel(),
...@@ -742,6 +851,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
rank_in_group,
comm,
stream));
RecordEndEvent(event, stream);
},
dst_rank,
tensor_maybe_partial,
...
...@@ -77,6 +77,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
int size,
int gid);
~ProcessGroupNCCL();
std::string GetBackendName() const override { return "NCCL"; }
phi::DeviceContext* GetDeviceContext(const Place& place) const override;
...@@ -169,6 +171,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
ncclComm_t NCCLComm(const Place& place) const;
std::vector<double> GetEventTimeAndRelease();
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
...@@ -203,6 +207,16 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream);
public:
gpuEvent_t RecordStartEventOnCalcStream();
void RecordEndEventOnCalcStream(gpuEvent_t event);
private:
gpuEvent_t RecordStartEvent(gpuStream_t stream);
void RecordEndEvent(gpuEvent_t event, gpuStream_t stream);
private:
std::shared_ptr<phi::distributed::Store> store_;
...@@ -212,6 +226,13 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_;
struct Events {
std::vector<std::pair<gpuEvent_t, gpuEvent_t>> events;
size_t length{0};
};
std::vector<Events> events_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
static uint64_t s_group_call_counter;
...
...@@ -391,6 +391,7 @@ void BindCudaStream(py::module *m_ptr) {
event.synchronize()
)DOC")
.def("elapsed_time", &paddle::platform::CudaEvent::ElapsedTime)
#endif
.def(
"__init__",
...
...@@ -1224,7 +1224,24 @@ void BindDistributed(py::module *m) {
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def("_record_start_event_on_calc_stream",
[](distributed::ProcessGroup &self) -> uintptr_t {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported _record_start_event_on_calc_stream method."));
})
.def("_record_end_event_on_calc_stream",
[](distributed::ProcessGroupNCCL &self, uintptr_t event) {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported _record_end_event_on_calc_stream method."));
})
.def(
"_get_event_time_and_release",
[](distributed::ProcessGroup &self, bool accumulate) {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported _get_event_time_and_release method."));
},
py::arg("accumulate") = true);
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
py::class_<distributed::ProcessGroupNCCL,
...@@ -1238,8 +1255,45 @@ void BindDistributed(py::module *m) {
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>())
.def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
.def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd)
.def(
"_record_start_event_on_calc_stream",
[](distributed::ProcessGroupNCCL &self) -> uintptr_t {
return reinterpret_cast<uintptr_t>(
self.RecordStartEventOnCalcStream());
},
py::call_guard<py::gil_scoped_release>())
.def(
"_record_end_event_on_calc_stream",
[](distributed::ProcessGroupNCCL &self, uintptr_t event) {
return self.RecordEndEventOnCalcStream(
reinterpret_cast<gpuEvent_t>(event));
},
py::call_guard<py::gil_scoped_release>())
.def(
"_get_event_time_and_release",
[](distributed::ProcessGroupNCCL &self,
bool accumulate) -> py::object {
std::vector<double> times;
double total_ms = 0.0;
{
py::gil_scoped_release release;
times = self.GetEventTimeAndRelease();
if (accumulate) {
total_ms = std::accumulate(times.begin(), times.end(), 0.0);
}
}
if (accumulate) {
return py::cast(total_ms);
} else {
py::list obj(times.size());
for (size_t i = 0; i < times.size(); ++i) {
obj[i] = py::cast(times[i]);
}
return obj;
}
},
py::arg("accumulate") = true);
#endif
#if defined(PADDLE_WITH_MPI)
...
...@@ -205,6 +205,13 @@ class CudaEvent {
}
gpuEvent_t GetRawCudaEvent() { return event_; }
float ElapsedTime(const CudaEvent &end_event) const {
float ms;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventElapsedTime(&ms, event_, end_event.event_));
return ms;
}
private:
#ifdef PADDLE_WITH_HIP
unsigned int flags_ = hipEventDefault;
...
...@@ -62,6 +62,14 @@ bool Event::Query() const { return device_->QueryEvent(this); }
void Event::Synchronize() const { device_->SynchronizeEvent(this); }
double Event::ElapsedTime(const Event& end_event) const {
auto s_event = static_cast<gpuEvent_t>(event_);
auto e_event = static_cast<gpuEvent_t>(end_event.event_);
float ms;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventElapsedTime(&ms, s_event, e_event));
return ms;
}
const Place& Event::GetPlace() const { return place_; }
} // namespace event
...
...@@ -47,6 +47,7 @@ class Event {
void Record(const stream::Stream* stream);
bool Query() const;
void Synchronize() const;
double ElapsedTime(const Event& end_event) const;
const Place& GetPlace() const;
private:
...
...@@ -1247,3 +1247,7 @@ PADDLE_DEFINE_EXPORTED_bool(use_shm_cache,
PADDLE_DEFINE_EXPORTED_string(tensor_operants_mode,
"eager",
"Tensor operants mode");
PADDLE_DEFINE_EXPORTED_bool(enable_process_group_event_record,
false,
"Whether to enable process group event record.");
...@@ -77,6 +77,9 @@ class Group:
else:
return -1
def _get_event_time_and_release(self, accumulate=True):
return self._pg._get_event_time_and_release(accumulate)
def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id
...
...@@ -278,6 +278,7 @@ def batch_send_recv_on_calc_stream(p2p_op_list):
return
group = _get_global_group() if group is None else group
backend = group.backend
event = group.process_group._record_start_event_on_calc_stream()
with _with_batch_p2p_guard(backend):
for p2p_op in p2p_op_list:
op = p2p_op.op
...@@ -287,6 +288,7 @@ def batch_send_recv_on_calc_stream(p2p_op_list):
nranks = p2p_op.nranks
rank_id = p2p_op.rank_id
op(tensor, comm_group, peer, nranks, rank_id)
group.process_group._record_end_event_on_calc_stream(event)
def _process_p2p_tuple_or_tensor(
...
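The Python-side hooks added above can be exercised roughly as follows. This is a minimal usage sketch, not part of the commit: it assumes the job runs on two or more GPUs (e.g. launched via `paddle.distributed.launch`), that `FLAGS_enable_process_group_event_record` is set to true in the environment before startup, and that `paddle.distributed.new_group()` returns a `Group` backed by `ProcessGroupNCCL`.

```python
# Illustrative sketch only (not part of this commit).
# Assumes FLAGS_enable_process_group_event_record=true is set in the
# environment, so RecordStartEvent() records CUDA events instead of
# returning nullptr, and that the job runs on 2+ GPUs.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group()  # assumption: Group backed by ProcessGroupNCCL

x = paddle.ones([1024, 1024], dtype="float32")
for _ in range(10):
    # A start/end event pair is recorded around each ncclAllReduce call.
    dist.all_reduce(x, group=group)

# Sum of elapsed milliseconds over all recorded event pairs on this device;
# the per-device event pool is reset afterwards (accumulate=True by default).
total_ms = group._get_event_time_and_release()
print(f"rank {dist.get_rank()}: NCCL time {total_ms:.3f} ms")

# Per-call timings instead of a single sum:
# times = group._get_event_time_and_release(accumulate=False)
```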