未验证 提交 75685a8e 编写于 作者: J Juncheng 提交者: GitHub

Refactor ExecKernel::bn_in_op2regst_desc_id to bn_in_op2blob_info (#3744)

Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
上级 06be620f
......@@ -58,7 +58,6 @@ void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto,
for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) {
ExecKernel ek;
ek.kernel = ConstructKernel(job_desc_, node.kernel_conf(), device_ctx_.get());
ek.bn_in_op2regst_desc_id = PbMap2HashMap(node.bn_in_op2regst_desc_id());
exec_kernel_vec_.push_back(std::move(ek));
}
......@@ -105,6 +104,7 @@ void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto,
is_naive_consumed_eord_ = false;
TakeOverNaiveConsumed(task_proto.consumed_regst_desc_id());
TakeOverNaiveProduced(task_proto.produced_regst_desc());
InitBnInOp2BlobInfo(task_proto);
VirtualActorInit(task_proto);
}
......@@ -174,6 +174,42 @@ void Actor::TakeOverNaiveProduced(const PbMap<std::string, RegstDescProto>& prod
}
}
void Actor::InitBnInOp2BlobInfo(const TaskProto& task_proto) {
for (int64_t i = 0; i < exec_kernel_vec_.size(); ++i) {
ExecKernel& ek = exec_kernel_vec_.at(i);
const ExecNodeProto& node = task_proto.exec_sequence().exec_node(i);
for (auto& pair : node.kernel_conf().op_attribute().arg_signature().bn_in_op2lbi()) {
BlobInfo blob_info;
blob_info.lbi = pair.second;
const std::string& bn = pair.first;
auto regst_desc_id_it = node.bn_in_op2regst_desc_id().find(bn);
if (regst_desc_id_it != node.bn_in_op2regst_desc_id().end()) {
const int64_t regst_desc_id = regst_desc_id_it->second;
blob_info.regst_desc_id = regst_desc_id;
const RtRegstDesc& regst_desc =
Global<RegstMgr>::Get()->RegstDesc4RegstDescId(regst_desc_id);
blob_info.ordinal = regst_desc.GetOrdinalForLbi(blob_info.lbi);
if (naive_produced_rs_.HasRegstDescId(regst_desc_id)) {
blob_info.rs = &naive_produced_rs_;
} else if (inplace_produced_rs_.HasRegstDescId(regst_desc_id)) {
blob_info.rs = &inplace_produced_rs_;
} else if (naive_consumed_rs_.HasRegstDescId(regst_desc_id)) {
blob_info.rs = &naive_consumed_rs_;
} else if (inplace_consumed_rs_.HasRegstDescId(regst_desc_id)) {
blob_info.rs = &inplace_consumed_rs_;
} else {
blob_info.rs = nullptr;
}
} else {
blob_info.regst_desc_id = -1;
blob_info.ordinal = -1;
blob_info.rs = nullptr;
}
ek.bn_in_op2blob_info.emplace(bn, std::move(blob_info));
}
}
}
void Actor::ForEachProducedRegst(const std::function<void(Regst*)>& Handler) const {
for (const auto& pair : produced_regsts_) {
for (const auto& regst : pair.second) { Handler(regst.get()); }
......@@ -515,14 +551,22 @@ void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx,
std::function<Regst*(int64_t)> Regst4RegstDescId) {
// Launches every kernel in exec_kernel_vec_, giving each one a callback that maps a blob
// name in op to its Blob*, or to nullptr when no blob can be resolved for that name.
//
// NOTE(review): this listing is a diff view, so the callback body below shows two lookup
// paths back to back: one through bn_in_op2regst_desc_id and one through
// bn_in_op2blob_info. The commit title says the former is refactored into the latter;
// confirm which lines belong to the final version before relying on this text.
for (const ExecKernel& ek : exec_kernel_vec_) {
ek.kernel->Launch(kernel_ctx, [&](const std::string& bn_in_op) -> Blob* {
// Lookup through bn_in_op2regst_desc_id: an unknown blob name yields nullptr; otherwise
// try the current writeable register, then the current readable one, then the
// caller-supplied Regst4RegstDescId.
auto regst_desc_id_it = ek.bn_in_op2regst_desc_id.find(bn_in_op);
if (regst_desc_id_it == ek.bn_in_op2regst_desc_id.end()) { return nullptr; }
Regst* regst = GetNaiveOrInplaceCurWriteable(regst_desc_id_it->second);
if (regst == nullptr) { regst = GetNaiveOrInplaceCurReadable(regst_desc_id_it->second); }
if (regst == nullptr) { regst = Regst4RegstDescId(regst_desc_id_it->second); }
// Lookup through bn_in_op2blob_info: an unknown blob name, or one whose regst_desc_id is
// -1, yields nullptr.
const auto blob_info_it = ek.bn_in_op2blob_info.find(bn_in_op);
if (blob_info_it == ek.bn_in_op2blob_info.cend()) { return nullptr; }
const BlobInfo& info = blob_info_it->second;
if (info.regst_desc_id == -1) { return nullptr; }
Regst* regst;
// When a slot was recorded for this register, take the slot's front register; otherwise
// ask the caller-supplied Regst4RegstDescId.
if (info.rs != nullptr) {
regst = info.rs->Front(info.regst_desc_id);
} else {
regst = Regst4RegstDescId(info.regst_desc_id);
}
if (regst == nullptr) { return nullptr; }
// Fetch the blob by logical blob id, obtained from the kernel for this blob name.
const LogicalBlobId& lbi = ek.kernel->BnInOp2Lbi(bn_in_op);
return regst->GetBlobByLbi(lbi);
// Fetch the blob by ordinal when one is recorded (>= 0); otherwise fall back to the
// logical blob id stored in the BlobInfo.
if (info.ordinal >= 0) {
return regst->GetBlobByOrdinal(info.ordinal);
} else {
return regst->GetBlobByLbi(info.lbi);
}
});
}
}
......
......@@ -53,9 +53,15 @@ class Actor {
// Returns the id stored in actor_id_.
int64_t actor_id() const { return actor_id_; }
protected:
// Pre-resolved lookup data for the blob bound to one blob name in op of a kernel.
//
// The numeric and pointer members default to the "no register bound" state, so a
// default-constructed BlobInfo never carries indeterminate values.
struct BlobInfo {
  // Logical blob id the blob name maps to; used to fetch the blob when no ordinal is known.
  LogicalBlobId lbi;
  // Id of the register descriptor bound to the blob name, or -1 when there is none.
  int64_t regst_desc_id = -1;
  // Ordinal of the blob inside its register; a negative value means "look up by lbi".
  int64_t ordinal = -1;
  // Slot containing the register, or nullptr when it is in none of the actor's
  // naive/inplace produced/consumed slots (the register is then resolved via a callback).
  RegstSlot* rs = nullptr;
};
// Per-exec-node state kept by the actor: the constructed kernel and lookup tables keyed by
// blob name in op.
struct ExecKernel {
// Kernel constructed for this exec node (see ConstructKernel in Actor::Init).
std::unique_ptr<const Kernel> kernel;
// Blob name in op -> register descriptor id.
// NOTE(review): the commit title says this map is refactored into bn_in_op2blob_info;
// this diff view lists both members — confirm whether this one remains in the final struct.
HashMap<std::string, int64_t> bn_in_op2regst_desc_id;
// Blob name in op -> BlobInfo (lbi, register descriptor id, ordinal, owning slot),
// populated by Actor::InitBnInOp2BlobInfo.
HashMap<std::string, BlobInfo> bn_in_op2blob_info;
};
using MsgHandler = int (Actor::*)(const ActorMsg&);
enum class RegstNameType { kNaive = 0, kCustomized };
......@@ -182,6 +188,7 @@ class Actor {
const PbMap<std::string, RegstDescProto>& produced_ids);
void TakeOverNaiveConsumed(const PbMap<std::string, RegstDescIdSet>& consumed_ids);
void TakeOverNaiveProduced(const PbMap<std::string, RegstDescProto>& produced_ids);
void InitBnInOp2BlobInfo(const TaskProto& task_proto);
// Send Msgs
void AsyncSendNaiveProducedRegstMsgToConsumer();
......
......@@ -23,7 +23,7 @@ void CaseCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().output_bns_size();
FOR_RANGE(int64_t, i, 0, output_bns_size) {
const int64_t regst_desc_id =
exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(GenRepeatedBn("out", i));
exec_kernel_vec().at(0).bn_in_op2blob_info.at(GenRepeatedBn("out", i)).regst_desc_id;
CHECK(out_bn_id2regst_desc_id_.emplace(i, regst_desc_id).second);
}
TakeOverConsumedRegst(task_proto.consumed_regst_desc_id());
......
......@@ -23,7 +23,7 @@ void EsacCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
task_proto.exec_sequence().exec_node().Get(0).kernel_conf().op_attribute().input_bns_size();
FOR_RANGE(int64_t, i, 0, input_bns_size) {
const int64_t regst_desc_id =
exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(GenRepeatedBn("in", i));
exec_kernel_vec().at(0).bn_in_op2blob_info.at(GenRepeatedBn("in", i)).regst_desc_id;
CHECK(regst_desc_id2in_bn_id_.emplace(regst_desc_id, i).second);
}
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
......
......@@ -25,10 +25,10 @@ void InputWiseCompActor::Init(const TaskProto& task_proto) {
for (int64_t i = 0; i < input_bns.size(); ++i) {
CHECK(ibn2in_bn_id.emplace(input_bns.Get(i), i).second);
}
for (const auto& pair : exec_kernel_vec().at(0).bn_in_op2regst_desc_id) {
for (const auto& pair : exec_kernel_vec().at(0).bn_in_op2blob_info) {
auto it = ibn2in_bn_id.find(pair.first);
if (it != ibn2in_bn_id.end()) {
CHECK(regst_desc_id2in_bn_id_.emplace(pair.second, it->second).second);
CHECK(regst_desc_id2in_bn_id_.emplace(pair.second.regst_desc_id, it->second).second);
}
}
......
......@@ -22,7 +22,7 @@ void ReentrantLockCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf();
const auto& ibns = kernel_conf.op_attribute().input_bns();
for (const auto& ibn : ibns) {
int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(ibn);
int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2blob_info.at(ibn).regst_desc_id;
if (ibn == "start") { eord_regst_desc_id_ = regst_desc_id; }
CHECK(regst_desc_id2ibn_.emplace(regst_desc_id, ibn).second);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册