未验证 提交 f6b23d6d 编写于 作者: J jameszhang 提交者: GitHub

use default XPU stream for computing (#49806)

* revert to use default XPU stream for computing

XPUContext now has a null stream by default. If you want to use a separate stream
 (e.g. in async collective communication), you should create a dedicated XPUContext
and invoke its XPUContext::CreateStream()

* minor
上级 77376727
......@@ -136,6 +136,8 @@ void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place,
BKCLContext_t bkcl_comm;
BKCLCHECK(bkcl_init_rank(&bkcl_comm, GetRank(), GetSize(), &bkcl_id));
comm_ctx->SetBkclContext(bkcl_comm);
// comm context creates a separate XPU stream for communication
comm_ctx->CreateStream();
place_to_calc_ctx_[place_key] = calc_ctx;
place_to_comm_ctx_[place_key] = std::move(comm_ctx);
......
......@@ -183,6 +183,7 @@ class XPUDeviceContext : public phi::XPUContext {
virtual ~XPUDeviceContext();
Eigen::DefaultDevice* eigen_device() const { return nullptr; }
xpuStream stream() const { return XPUContext::x_context()->xpu_stream; }
void CreateStream() { XPUContext::CreateStream(); }
};
template <>
......
......@@ -61,11 +61,13 @@ struct XPUContext::Impl {
~Impl() {
if (owned_ && context_ != nullptr) {
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
// manually destroy XPUStream here until xpu::api integrates this work
// into Context dtor
xpu_wait(context_->xpu_stream);
xpu_stream_destroy(context_->xpu_stream);
context_->xpu_stream = nullptr;
if (context_->xpu_stream) {
// manually destroy XPUStream here until xpu::api integrates this work
// into Context dtor
xpu_stream_destroy(context_->xpu_stream);
context_->xpu_stream = nullptr;
}
xpu::destroy_context(context_);
context_ = nullptr;
}
......@@ -73,11 +75,7 @@ struct XPUContext::Impl {
const Place& GetPlace() const { return place_; }
XPUStream stream() const {
auto s = context_->xpu_stream;
PD_CHECK(s != nullptr, "the xpu stream is nullptr.");
return s;
}
XPUStream stream() const { return context_->xpu_stream; }
xpu::Context* GetXContext() const {
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
......@@ -103,13 +101,20 @@ struct XPUContext::Impl {
context_ = xpu::create_context();
xpu_version_ = backends::xpu::get_xpu_version(place_.device);
SetL3Cache();
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
}
void SetXContext(xpu::Context* context) { context_ = context; }
void SetBkclContext(xpu::BKCLContext_t context) { bkcl_context_ = context; }
void CreateStream() {
if (context_->xpu_stream) {
VLOG(3) << "xpu stream is already created for current context";
return;
}
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
}
bool owned_{false};
Place place_;
backends::xpu::XPUVersion xpu_version_;
......@@ -153,6 +158,8 @@ void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
impl_->SetBkclContext(context);
}
void XPUContext::CreateStream() { impl_->CreateStream(); }
void XPUContext::Init() { impl_->Init(); }
} // namespace phi
......@@ -46,6 +46,7 @@ class XPUContext : public DeviceContext,
// Return bkcl context.
xpu::BKCLContext_t bkcl_context() const;
void SetBkclContext(xpu::BKCLContext_t context);
void CreateStream();
// Wait for all operations completion in the stream.
void Wait() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册