Commit 90ed96c3 authored by dingminghui, committed by jackzhang235

feat(mlu): support multiqueue

Parent 1435d179
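
This commit replaces the fixed exec/io queue indices in `MLUContext` with a shared map that assigns hardware queues round-robin: the first time a logical `exec_queue_id` key is seen it is bound to `next_queue_id_ % max_queue()`, and every later lookup of the same key reuses that queue. A minimal standalone sketch of that assignment scheme follows; `PickQueueId` and `max_queue` are illustrative stand-ins, not names from the patch.

```cpp
#include <map>
#include <mutex>

// Round-robin binding of logical queue keys to device queue slots, mirroring
// the static map/mutex/counter added to MLUContext in this commit.
static int next_queue_id = 0;            // plays the role of next_queue_id_
static std::map<int, int> queue_id_map;  // plays the role of queue_id_map_
static std::mutex map_mutex;             // plays the role of map_mutex_

int PickQueueId(int exec_queue_id, int max_queue /* devs[dev_id].max_queue() */) {
  std::lock_guard<std::mutex> lock(map_mutex);
  auto it = queue_id_map.find(exec_queue_id);
  if (it == queue_id_map.end()) {
    // Unseen key: hand out the next hardware queue, wrapping around once all
    // queues are in use, then remember the binding for future lookups.
    it = queue_id_map.emplace(exec_queue_id, next_queue_id++ % max_queue).first;
  }
  return it->second;
}
```

Callers that pass the same key (for example an id derived from the same graph) therefore share one MLU queue, while different keys spread across the device's queues.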
@@ -21,5 +21,11 @@ namespace lite {
 thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
 #endif

+#ifdef LITE_WITH_MLU
+int Context<TargetType::kMLU>::next_queue_id_{0};
+std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
+std::mutex Context<TargetType::kMLU>::map_mutex_;
+#endif
+
 }  // namespace lite
 }  // namespace paddle
@@ -26,6 +26,7 @@
 #ifdef LITE_WITH_MLU
 #include <cnml.h>
 #include <cnrt.h>
+#include <mutex>  // NOLINT
 #include "lite/backends/mlu/mlu_utils.h"
 #endif
 #ifdef LITE_WITH_XPU
@@ -230,11 +231,11 @@ class Context<TargetType::kMLU> {
   void InitOnce() {}

   MLUContext& operator=(const MLUContext& ctx) {
-    this->Init(ctx.device_id_, ctx.exec_queue_id_, ctx.io_queue_id_);
+    this->Init(ctx.device_id_, ctx.exec_queue_id_);
     return *this;
   }

-  void Init(int dev_id, int exec_queue_id = 0, int io_queue_id = 0) {
+  void Init(int dev_id, int exec_queue_id = 0) {
     CHECK_GT(devs.size(), 0UL)
         << "Env is not initialized or current target is not exit!";
     if (dev_id >= static_cast<int>(devs.size())) {
@@ -245,21 +246,19 @@ class Context<TargetType::kMLU> {
       device_id_ = dev_id;
     }
     SetMluDevice(device_id_);
-    if (io_queue_id >= devs[dev_id].max_queue()) {
-      LOG(WARNING) << "data queue index exceeds the maximum queue number, "
-                      "set to default qeueu(0)!";
-      io_queue_id = 0;
-    }
-    if (exec_queue_id >= devs[dev_id].max_queue()) {
-      LOG(WARNING) << "exec queue index exceeds the maximum queue number, "
-                      "set to default qeueu(0)!";
-      exec_queue_id = 0;
+    // get queue id from map
+    std::unique_lock<std::mutex> lk(map_mutex_);
+    if (queue_id_map_.find(exec_queue_id) == queue_id_map_.end()) {
+      queue_id_map_[exec_queue_id] =
+          next_queue_id_++ % devs[dev_id].max_queue();
     }
-    io_queue_ = devs[dev_id].io_queues()[io_queue_id];
-    exec_queue_ = devs[dev_id].exec_queues()[exec_queue_id];
+    exec_queue_id_ = queue_id_map_[exec_queue_id];
+    VLOG(4) << "pick mlu queue id: " << exec_queue_id_;
+    lk.unlock();
-    exec_queue_id_ = exec_queue_id;
-    io_queue_id_ = io_queue_id;
+    io_queue_ = devs[dev_id].io_queues()[exec_queue_id_];
+    exec_queue_ = devs[dev_id].exec_queues()[exec_queue_id_];
   }

   void CopySharedTo(MLUContext* ctx) { ctx->forward_param_ = forward_param_; }
@@ -287,10 +286,12 @@ class Context<TargetType::kMLU> {
   std::string name() const { return "MLUContext"; }

  private:
+  static int next_queue_id_;
+  static std::map<int, int> queue_id_map_;
+  static std::mutex map_mutex_;
   int device_id_;
   // overall information
   int exec_queue_id_;
-  int io_queue_id_;
   cnrtQueue_t io_queue_;
   cnrtQueue_t exec_queue_;
@@ -444,7 +445,7 @@ class ContextScheduler {
      case TARGET(kMLU): {
        int dev_id = TargetWrapper<TargetType::kMLU>::GetCurDevice();
        auto& context = ctx->As<MLUContext>();
-       context.Init(dev_id);
+       context.Init(dev_id, exec_stream_id);
        kernel_contexts_[TargetType::kMLU].As<MLUContext>().CopySharedTo(
            &context);
        LOG(INFO) << "New Context for MLU";
......
@@ -26,6 +26,8 @@ namespace paddle {
 namespace lite {
 namespace mir {

+static thread_local int g_stream_id = 0;
+
 Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
                                            const std::string& cast_arg_name,
                                            SSAGraph* graph,
@@ -97,8 +99,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
       // we pick the kernel
       cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
       auto& stmt = cast_inst->AsStmt();
-      stmt.picked_kernel().SetContext(
-          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
+      stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+          stmt.picked_kernel().target(), g_stream_id));
       break;
     }
   }
@@ -182,8 +184,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
       // we pick the kernel
       cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
       auto& stmt = cast_inst->AsStmt();
-      stmt.picked_kernel().SetContext(
-          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
+      stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+          stmt.picked_kernel().target(), g_stream_id));
       break;
     }
   }
@@ -620,6 +622,7 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   }
 #endif

+  g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
   for (auto& node : graph->mutable_nodes()) {
     if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
......
@@ -44,6 +44,10 @@ class RuntimeContextAssignPass : public StmtPass {
      inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
          inst.picked_kernel().target()));
    }
+#elif LITE_WITH_MLU
+    inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+        inst.picked_kernel().target(),
+        static_cast<int>(reinterpret_cast<int64_t>(graph.get()))));
 #else
     int stream_id = inst.stream_id_;
......
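
The stream id threaded through `NewContext()` above is not chosen by the user: both passes derive it from the address of the graph being compiled, so all kernels and inserted cast ops of one `SSAGraph` end up on the same queue. A hedged sketch of that derivation is below; the `SSAGraph` stub and the `StreamKeyFor` helper are placeholders, not code from the patch.

```cpp
#include <cstdint>
#include <memory>

struct SSAGraph {};  // stand-in for lite::mir::SSAGraph

// Same idea as g_stream_id and the #elif LITE_WITH_MLU branch: narrow the
// graph pointer to an int and use it as the NewContext() key. The value
// itself is arbitrary; MLUContext::Init() folds it into [0, max_queue) via
// the round-robin table, so only "same graph => same key" matters.
int StreamKeyFor(const std::unique_ptr<SSAGraph>& graph) {
  return static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
}
```

Truncating the pointer can in principle collide across graphs, but a collision only means two graphs share a queue, which does not affect correctness.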