提交 90ed96c3 编写于 作者: D dingminghui 提交者: jackzhang235

feat(mlu): support multiqueue

上级 1435d179
@@ -21,5 +21,11 @@ namespace lite {
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
#endif
#ifdef LITE_WITH_MLU
int Context<TargetType::kMLU>::next_queue_id_{0};
std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
std::mutex Context<TargetType::kMLU>::map_mutex_;
#endif
}  // namespace lite
}  // namespace paddle
@@ -26,6 +26,7 @@
#ifdef LITE_WITH_MLU
#include <cnml.h>
#include <cnrt.h>
#include <mutex> // NOLINT
#include "lite/backends/mlu/mlu_utils.h"
#endif
#ifdef LITE_WITH_XPU
@@ -230,11 +231,11 @@ class Context<TargetType::kMLU> {
  void InitOnce() {}

  MLUContext& operator=(const MLUContext& ctx) {
-    this->Init(ctx.device_id_, ctx.exec_queue_id_, ctx.io_queue_id_);
+    this->Init(ctx.device_id_, ctx.exec_queue_id_);
    return *this;
  }

-  void Init(int dev_id, int exec_queue_id = 0, int io_queue_id = 0) {
+  void Init(int dev_id, int exec_queue_id = 0) {
    CHECK_GT(devs.size(), 0UL)
        << "Env is not initialized or current target is not exit!";
    if (dev_id >= static_cast<int>(devs.size())) {
@@ -245,21 +246,19 @@ class Context<TargetType::kMLU> {
      device_id_ = dev_id;
    }
    SetMluDevice(device_id_);
-    if (io_queue_id >= devs[dev_id].max_queue()) {
-      LOG(WARNING) << "data queue index exceeds the maximum queue number, "
-                      "set to default qeueu(0)!";
-      io_queue_id = 0;
-    }
-    if (exec_queue_id >= devs[dev_id].max_queue()) {
-      LOG(WARNING) << "exec queue index exceeds the maximum queue number, "
-                      "set to default qeueu(0)!";
-      exec_queue_id = 0;
+    // get queue id from map
+    std::unique_lock<std::mutex> lk(map_mutex_);
+    if (queue_id_map_.find(exec_queue_id) == queue_id_map_.end()) {
+      queue_id_map_[exec_queue_id] =
+          next_queue_id_++ % devs[dev_id].max_queue();
    }
-    io_queue_ = devs[dev_id].io_queues()[io_queue_id];
-    exec_queue_ = devs[dev_id].exec_queues()[exec_queue_id];
-
-    exec_queue_id_ = exec_queue_id;
-    io_queue_id_ = io_queue_id;
+    exec_queue_id_ = queue_id_map_[exec_queue_id];
+    VLOG(4) << "pick mlu queue id: " << exec_queue_id_;
+    lk.unlock();
+    io_queue_ = devs[dev_id].io_queues()[exec_queue_id_];
+    exec_queue_ = devs[dev_id].exec_queues()[exec_queue_id_];
  }

  void CopySharedTo(MLUContext* ctx) { ctx->forward_param_ = forward_param_; }
@@ -287,10 +286,12 @@ class Context<TargetType::kMLU> {
  std::string name() const { return "MLUContext"; }

 private:
+  static int next_queue_id_;
+  static std::map<int, int> queue_id_map_;
+  static std::mutex map_mutex_;
  int device_id_;
  // overall information
  int exec_queue_id_;
-  int io_queue_id_;
  cnrtQueue_t io_queue_;
  cnrtQueue_t exec_queue_;
@@ -444,7 +445,7 @@ class ContextScheduler {
      case TARGET(kMLU): {
        int dev_id = TargetWrapper<TargetType::kMLU>::GetCurDevice();
        auto& context = ctx->As<MLUContext>();
-        context.Init(dev_id);
+        context.Init(dev_id, exec_stream_id);
        kernel_contexts_[TargetType::kMLU].As<MLUContext>().CopySharedTo(
            &context);
        LOG(INFO) << "New Context for MLU";
......
@@ -26,6 +26,8 @@ namespace paddle {
namespace lite {
namespace mir {
static thread_local int g_stream_id = 0;
Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
                                           const std::string& cast_arg_name,
                                           SSAGraph* graph,
@@ -97,8 +99,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
      // we pick the kernel
      cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
      auto& stmt = cast_inst->AsStmt();
-      stmt.picked_kernel().SetContext(
-          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
+      stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+          stmt.picked_kernel().target(), g_stream_id));
      break;
    }
  }
@@ -182,8 +184,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
      // we pick the kernel
      cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
      auto& stmt = cast_inst->AsStmt();
-      stmt.picked_kernel().SetContext(
-          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
+      stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+          stmt.picked_kernel().target(), g_stream_id));
      break;
    }
  }
@@ -620,6 +622,7 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  }
#endif
g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
  // insert io_copy, layout and precision cast of subgraph's inputs and outputs
  for (auto& node : graph->mutable_nodes()) {
    if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
......
@@ -44,6 +44,10 @@ class RuntimeContextAssignPass : public StmtPass {
        inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
            inst.picked_kernel().target()));
      }
#elif LITE_WITH_MLU
inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
inst.picked_kernel().target(),
static_cast<int>(reinterpret_cast<int64_t>(graph.get()))));
#else
      int stream_id = inst.stream_id_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册