未验证 提交 bcec0dce 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix SyncDefaultStream for process_group_custom (#51618)

* [CustomDevice] fix SyncDefaultStream for process_group_custom

* update
上级 1e232e27
...@@ -38,8 +38,8 @@ void SyncDefaultStream( ...@@ -38,8 +38,8 @@ void SyncDefaultStream(
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<platform::CustomDeviceContext*>( auto* default_ctx = static_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(places[i])); platform::DeviceContextPool::Instance().Get(places[i]));
cclEvents[i].Record(*dev_ctx[i]); cclEvents[i].Record(*default_ctx);
cclEvents[i].Block(*default_ctx); cclEvents[i].Block(*dev_ctx[i]);
} }
} }
...@@ -74,8 +74,7 @@ void ProcessGroupCustom::CustomTask::SynchronizeStreams() { ...@@ -74,8 +74,7 @@ void ProcessGroupCustom::CustomTask::SynchronizeStreams() {
auto* default_ctx = static_cast<platform::CustomDeviceContext*>( auto* default_ctx = static_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(places_[i])); platform::DeviceContextPool::Instance().Get(places_[i]));
phi::DeviceGuard guard(default_ctx->GetPlace()); phi::DeviceGuard guard(default_ctx->GetPlace());
phi::stream::Stream stream(default_ctx->GetPlace(), default_ctx->stream()); control_events_[i].Block(*default_ctx);
stream.WaitEvent(control_events_[i].GetCustomEvent());
} }
} }
...@@ -118,7 +117,7 @@ ProcessGroupCustom::ProcessGroupCustom( ...@@ -118,7 +117,7 @@ ProcessGroupCustom::ProcessGroupCustom(
int rank, int rank,
int size, int size,
int gid) int gid)
: ProcessGroupWithoutStream(rank, size, gid), : ProcessGroupWithStream(rank, size, gid),
store_(store), store_(store),
device_type_(device_type) {} device_type_(device_type) {}
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h" #include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -36,7 +36,7 @@ namespace distributed { ...@@ -36,7 +36,7 @@ namespace distributed {
using Place = paddle::platform::Place; using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext; using CustomDeviceContext = paddle::platform::CustomDeviceContext;
class ProcessGroupCustom : public ProcessGroupWithoutStream { class ProcessGroupCustom : public ProcessGroupWithStream {
public: public:
class CustomTask : public ProcessGroup::Task, class CustomTask : public ProcessGroup::Task,
public std::enable_shared_from_this<CustomTask> { public std::enable_shared_from_this<CustomTask> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册