未验证 提交 f6b23d6d 编写于 作者: J jameszhang 提交者: GitHub

use default XPU stream for computing (#49806)

* revert to use default XPU stream for computing

XPUContext now has a null stream by default. If you want to use a separate stream
 (e.g. in async collective communication), you should create a dedicated XPUContext
and invoke its XPUContext::CreateStream()

* minor
上级 77376727
...@@ -136,6 +136,8 @@ void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, ...@@ -136,6 +136,8 @@ void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place,
BKCLContext_t bkcl_comm; BKCLContext_t bkcl_comm;
BKCLCHECK(bkcl_init_rank(&bkcl_comm, GetRank(), GetSize(), &bkcl_id)); BKCLCHECK(bkcl_init_rank(&bkcl_comm, GetRank(), GetSize(), &bkcl_id));
comm_ctx->SetBkclContext(bkcl_comm); comm_ctx->SetBkclContext(bkcl_comm);
// comm context creates a separate XPU stream for communication
comm_ctx->CreateStream();
place_to_calc_ctx_[place_key] = calc_ctx; place_to_calc_ctx_[place_key] = calc_ctx;
place_to_comm_ctx_[place_key] = std::move(comm_ctx); place_to_comm_ctx_[place_key] = std::move(comm_ctx);
......
...@@ -183,6 +183,7 @@ class XPUDeviceContext : public phi::XPUContext { ...@@ -183,6 +183,7 @@ class XPUDeviceContext : public phi::XPUContext {
virtual ~XPUDeviceContext(); virtual ~XPUDeviceContext();
Eigen::DefaultDevice* eigen_device() const { return nullptr; } Eigen::DefaultDevice* eigen_device() const { return nullptr; }
xpuStream stream() const { return XPUContext::x_context()->xpu_stream; } xpuStream stream() const { return XPUContext::x_context()->xpu_stream; }
void CreateStream() { XPUContext::CreateStream(); }
}; };
template <> template <>
......
...@@ -61,11 +61,13 @@ struct XPUContext::Impl { ...@@ -61,11 +61,13 @@ struct XPUContext::Impl {
~Impl() { ~Impl() {
if (owned_ && context_ != nullptr) { if (owned_ && context_ != nullptr) {
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
// manually destroy XPUStream here until xpu::api integrates this work
// into Context dtor
xpu_wait(context_->xpu_stream); xpu_wait(context_->xpu_stream);
xpu_stream_destroy(context_->xpu_stream); if (context_->xpu_stream) {
context_->xpu_stream = nullptr; // manually destroy XPUStream here until xpu::api integrates this work
// into Context dtor
xpu_stream_destroy(context_->xpu_stream);
context_->xpu_stream = nullptr;
}
xpu::destroy_context(context_); xpu::destroy_context(context_);
context_ = nullptr; context_ = nullptr;
} }
...@@ -73,11 +75,7 @@ struct XPUContext::Impl { ...@@ -73,11 +75,7 @@ struct XPUContext::Impl {
const Place& GetPlace() const { return place_; } const Place& GetPlace() const { return place_; }
XPUStream stream() const { XPUStream stream() const { return context_->xpu_stream; }
auto s = context_->xpu_stream;
PD_CHECK(s != nullptr, "the xpu stream is nullptr.");
return s;
}
xpu::Context* GetXContext() const { xpu::Context* GetXContext() const {
PD_CHECK(context_ != nullptr, "the xpu context is nullptr."); PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
...@@ -103,13 +101,20 @@ struct XPUContext::Impl { ...@@ -103,13 +101,20 @@ struct XPUContext::Impl {
context_ = xpu::create_context(); context_ = xpu::create_context();
xpu_version_ = backends::xpu::get_xpu_version(place_.device); xpu_version_ = backends::xpu::get_xpu_version(place_.device);
SetL3Cache(); SetL3Cache();
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
} }
void SetXContext(xpu::Context* context) { context_ = context; } void SetXContext(xpu::Context* context) { context_ = context; }
void SetBkclContext(xpu::BKCLContext_t context) { bkcl_context_ = context; } void SetBkclContext(xpu::BKCLContext_t context) { bkcl_context_ = context; }
void CreateStream() {
if (context_->xpu_stream) {
VLOG(3) << "xpu stream is already created for current context";
return;
}
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
}
bool owned_{false}; bool owned_{false};
Place place_; Place place_;
backends::xpu::XPUVersion xpu_version_; backends::xpu::XPUVersion xpu_version_;
...@@ -153,6 +158,8 @@ void XPUContext::SetBkclContext(xpu::BKCLContext_t context) { ...@@ -153,6 +158,8 @@ void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
impl_->SetBkclContext(context); impl_->SetBkclContext(context);
} }
void XPUContext::CreateStream() { impl_->CreateStream(); }
void XPUContext::Init() { impl_->Init(); } void XPUContext::Init() { impl_->Init(); }
} // namespace phi } // namespace phi
...@@ -46,6 +46,7 @@ class XPUContext : public DeviceContext, ...@@ -46,6 +46,7 @@ class XPUContext : public DeviceContext,
// Return bkcl context. // Return bkcl context.
xpu::BKCLContext_t bkcl_context() const; xpu::BKCLContext_t bkcl_context() const;
void SetBkclContext(xpu::BKCLContext_t context); void SetBkclContext(xpu::BKCLContext_t context);
void CreateStream();
// Wait for all operations completion in the stream. // Wait for all operations completion in the stream.
void Wait() const override; void Wait() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册