From a260b163783de06cf74320df41c09a236744136c Mon Sep 17 00:00:00 2001 From: Jingwu Chen <1163643073@qq.com> Date: Sat, 20 Jan 2018 16:34:14 +0800 Subject: [PATCH] boxing for both cnn & rnn (#513) * boxing for both cnn & rnn * add ColIdOrder * order uncertain * fix boxing * last -- max * boxing kernel context * fix int32 * new delete previous_pid_cid_ * nullptr after delete * fix bug * judge when max_col_num > 1 --- oneflow/core/actor/boxing_actor.cpp | 52 ++++++++++++++++++++++++++--- oneflow/core/actor/boxing_actor.h | 10 ++++-- oneflow/core/kernel/kernel_util.cpp | 1 - 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/oneflow/core/actor/boxing_actor.cpp b/oneflow/core/actor/boxing_actor.cpp index a432ed79b3..27a5508e38 100644 --- a/oneflow/core/actor/boxing_actor.cpp +++ b/oneflow/core/actor/boxing_actor.cpp @@ -4,20 +4,50 @@ namespace oneflow { void BoxingActor::VirtualActorInit(const TaskProto& task_proto) { - is_eord_ = false; for (const auto& pair : task_proto.consumed_regst_desc_id()) { readable_regst_[pair.second] = {}; } + previous_pid_cid_ = new HashMap>; readable_regst_cnt_ = 0; + col_id_order_ = ColIdOrder::kUnCertain; + is_eord_ = false; OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal); } +void BoxingActor::TrySetColIdOrder(const Regst* cur_regst) { + int64_t regst_desc_id = cur_regst->regst_desc_id(); + int64_t cur_pid = cur_regst->piece_id(); + int32_t cur_cid = cur_regst->col_id(); + if (previous_pid_cid_->find(regst_desc_id) == previous_pid_cid_->end()) { + (*previous_pid_cid_)[regst_desc_id] = std::make_pair(cur_pid, cur_cid); + return; + } + auto& pre_pid_cid = previous_pid_cid_->at(regst_desc_id); + if (pre_pid_cid.first != cur_pid) { + pre_pid_cid = std::make_pair(cur_pid, cur_cid); + return; + } + if (cur_cid == pre_pid_cid.second + 1) { + col_id_order_ = ColIdOrder::kAscending; + } else { + CHECK_EQ(cur_cid, pre_pid_cid.second - 1); + col_id_order_ = ColIdOrder::kDescending; + } + delete previous_pid_cid_; + previous_pid_cid_ = nullptr; + return; +} + int BoxingActor::HandlerNormal(const ActorMsg& msg) { if (msg.msg_type() == ActorMsgType::kEordMsg) { is_eord_ = true; DecreaseRemainingEordCnt(); } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) { + if (msg.regst()->packed_blob()->max_col_num() > 1 + && col_id_order_ == ColIdOrder::kUnCertain) { + TrySetColIdOrder(msg.regst()); + } std::queue& rq = readable_regst_.at(msg.regst()->regst_desc_id()); if (rq.empty()) { readable_regst_cnt_ += 1; } rq.push(msg.regst()); @@ -30,7 +60,6 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) { } void BoxingActor::Act() { - int64_t piece_id = readable_regst_.begin()->second.front()->piece_id(); AsyncLaunchKernel(GenDefaultKernelCtx(), [this](int64_t regst_desc_id) -> Regst* { Regst* regst = GetCurWriteableRegst(regst_desc_id); @@ -41,10 +70,25 @@ void BoxingActor::Act() { } }); AsyncSendRegstMsgToConsumer([&](Regst* regst) { - regst->set_piece_id(piece_id); - return true; + regst->set_piece_id(regst->piece_id()); + return regst->col_id() <= regst->max_col_id(); }); + int32_t cur_max_cid = 0; + int32_t cur_max_maxcid = 0; + for (const auto& pair : readable_regst_) { + cur_max_cid = std::max(cur_max_cid, pair.second.front()->col_id()); + cur_max_maxcid = + std::max(cur_max_maxcid, pair.second.front()->max_col_id()); + } for (auto& pair : readable_regst_) { + if (col_id_order_ == ColIdOrder::kAscending) { + if (pair.second.front()->IsMaxCol() && cur_max_cid < cur_max_maxcid) { + continue; + } + } else if (col_id_order_ == ColIdOrder::kDescending) { + if (pair.second.front()->col_id() < cur_max_cid) { continue; } + } else { // do nothing + } AsyncSendRegstMsgToProducer(pair.second.front()); pair.second.pop(); if (pair.second.empty()) { readable_regst_cnt_ -= 1; } diff --git a/oneflow/core/actor/boxing_actor.h b/oneflow/core/actor/boxing_actor.h index 80a82ef6da..8cc02d8d33 100644 --- a/oneflow/core/actor/boxing_actor.h +++ b/oneflow/core/actor/boxing_actor.h @@ -21,9 +21,15 @@ class BoxingActor final : public Actor { bool IsReadAlwaysUnReadyFromNow() override; void AsyncReturnAllReadableRegst() override; - bool is_eord_; + void TrySetColIdOrder(const Regst*); + + // HashMap> readable_regst_; - int64_t readable_regst_cnt_; + // > + HashMap>* previous_pid_cid_; + int8_t readable_regst_cnt_; + ColIdOrder col_id_order_; + bool is_eord_; }; } // namespace oneflow diff --git a/oneflow/core/kernel/kernel_util.cpp b/oneflow/core/kernel/kernel_util.cpp index 622b417e3b..de256d05e5 100644 --- a/oneflow/core/kernel/kernel_util.cpp +++ b/oneflow/core/kernel/kernel_util.cpp @@ -54,7 +54,6 @@ void RandomUniformInitializer( blob->shape().elem_cnt(), static_cast(initializer_conf.min()), static_cast(initializer_conf.max()), random_seed, blob->mut_dptr()); } - template void RandomNormalInitializer( const RandomNormalInitializerConf& initializer_conf, uint32_t random_seed, -- GitLab