From 4f6038231115338f4009546a585db3302ed8a7f3 Mon Sep 17 00:00:00 2001 From: Niu Chong Date: Sun, 24 Dec 2017 17:07:25 +0800 Subject: [PATCH] Add BasicRnnBackwardActor and Fix typos in BasicRnnForwardActor (#466) * feat: add BasicRnnForwardComputeActor * feat: add RnnSourceComputeActor * fix: fix the bug of recurrent edge in BasicRnnForwardComputeActor 1. fix: not send EROD msg through actor recurrent edge 2. fix: not for BasicRnnForwardComputeActor, HandlerNormal() judge out_consume/out_produce by recurrent_flag * fix: fix the bug of PR comment * fix: fix bugs of BasicRnnForwardComputeActor * fix: fix some typos * fix: fix bugs and release model in HandlerNormal Former-commit-id: 8b9110ad9a8e7025c8300ce2ab1d2786c68433ca --- .../basic_rnn_backward_compute_actor.cpp | 287 ++++++++++++++++++ .../actor/basic_rnn_backward_compute_actor.h | 63 ++++ .../actor/basic_rnn_forward_compute_actor.cpp | 7 +- oneflow/core/job/task.proto | 1 + 4 files changed, 355 insertions(+), 3 deletions(-) create mode 100644 oneflow/core/actor/basic_rnn_backward_compute_actor.cpp create mode 100644 oneflow/core/actor/basic_rnn_backward_compute_actor.h diff --git a/oneflow/core/actor/basic_rnn_backward_compute_actor.cpp b/oneflow/core/actor/basic_rnn_backward_compute_actor.cpp new file mode 100644 index 0000000000..46eaf0744e --- /dev/null +++ b/oneflow/core/actor/basic_rnn_backward_compute_actor.cpp @@ -0,0 +1,287 @@ +#include "oneflow/core/actor/basic_rnn_backward_compute_actor.h" + +namespace oneflow { + +void BasicRnnBackwardCompActor::VirtualCompActorInit( + const TaskProto& task_proto) { + in_regst_desc_id_ = RegstDescId4Name("in"); + out_regst_desc_id_ = RegstDescId4Name("out"); + initial_hidden_regst_desc_id_ = RegstDescId4Name("initial_hidden"); + out_diff_regst_desc_id_ = RegstDescId4Name("out_diff"); + rec_acc_diff_regst_desc_id_ = RegstDescId4Name("rec_acc_diff"); + model_regst_desc_id_ = RegstDescId4Name("model"); + activation_regst_desc_id_ = RegstDescId4Name("activation"); + + is_out_diff_eord_ = false; + is_insert_to_back_ = true; + DecreaseRemainingEordCnt(); // no 'rec_acc_diff', else will cause deadlock + OF_SET_MSG_HANDLER(&BasicRnnBackwardCompActor::HandlerNormal); +} + +int BasicRnnBackwardCompActor::HandlerNormal(const ActorMsg& msg) { + if (msg.msg_type() == ActorMsgType::kEordMsg) { + if (msg.eord_regst_desc_id() == out_diff_regst_desc_id_) { + is_out_diff_eord_ = true; + } + DecreaseRemainingEordCnt(); + } else if (msg.msg_type() == ActorMsgType::kRegstMsg) { + Regst* cur_regst = msg.regst(); + if (TryUpdtStateAsProducedRegst(cur_regst) != 0) { + int64_t cur_regst_desc_id = cur_regst->regst_desc_id(); + const PieceStatus& cur_pst = cur_regst->piece_status(); + int64_t cur_pid = cur_pst.piece_id(); + int64_t cur_col_id = cur_pst.col_id(); + int64_t cur_model_vid = cur_regst->model_version_id(); + + if (cur_regst_desc_id == in_regst_desc_id_) { + pid2in_regsts_[cur_pid].push(cur_regst); // insert or push + } else if (cur_regst_desc_id == out_regst_desc_id_) { + pid2out_regsts_[cur_pid].push_back(cur_regst); // insert or pushback + if (cur_col_id == 0) { + model_vid2cnt_[cur_model_vid] += 1; // insert or add + model_vid2status_[cur_model_vid] = false; // insert or set + + if ((cur_model_vid > 0) + && (model_vid2status_.find(cur_model_vid - 1) + != model_vid2status_.end())) { + model_vid2status_.at(cur_model_vid - 1) = true; + if (model_vid2cnt_.find(cur_model_vid - 1) + == model_vid2cnt_.end()) { + RelModelByJudgingStatus(cur_model_vid - 1); + } + } + + if (cur_pid == GetLastPieceIdForModelVersionId(cur_model_vid)) { + model_vid2status_.at(cur_model_vid) = true; + } + if (cur_pid == RuntimeCtx::Singleton()->total_piece_num() - 1) { + model_vid2status_.at(cur_model_vid) = true; + } + } + } else if (cur_regst_desc_id == initial_hidden_regst_desc_id_) { + CHECK(pid2init_hid_regsts_.emplace(cur_pid, cur_regst).second); + } else if (cur_regst_desc_id == out_diff_regst_desc_id_) { + auto it = pid2out_diff_regsts_.find(cur_pid); + if (it == pid2out_diff_regsts_.end()) { + if (cur_col_id == 0) { + is_insert_to_back_ = true; + } else if (cur_pst.IsLastCol()) { + is_insert_to_back_ = false; + } else { + // do nothing + } + } + if (is_insert_to_back_) { + pid2out_diff_regsts_[cur_pid].push_back(cur_regst); // insert or push + } else { + pid2out_diff_regsts_[cur_pid].push_front(cur_regst); + } + } else if (cur_regst_desc_id == rec_acc_diff_regst_desc_id_) { + CHECK_EQ(-1, cur_regst->recurrent_flag()); + CHECK(pid2rec_acc_diff_regsts_.emplace(cur_pid, cur_regst).second); + } else if (cur_regst_desc_id == model_regst_desc_id_) { + CHECK(model_vid2model_regst_.emplace(cur_model_vid, cur_regst).second); + } + } + ActUntilFail(); + } else { + UNEXPECTED_RUN(); + } + return TrySwitchToZombieOrFinish(); +} + +bool BasicRnnBackwardCompActor::CheckModel_In_OutDiff_Activation( + Regst* out_regst) const { + const PieceStatus& cur_pst = out_regst->piece_status(); + int64_t cur_pid = cur_pst.piece_id(); + int64_t cur_model_vid = out_regst->model_version_id(); + + auto model_it = model_vid2model_regst_.find(cur_model_vid); + if (model_it == model_vid2model_regst_.end()) { return false; } + + auto in_it = pid2in_regsts_.find(cur_pid); + if (in_it == pid2in_regsts_.end()) { return false; } + if (cur_pst.IsLastCol()) { + if (in_it->second.top()->piece_status() != cur_pst) { return false; } + } else { + CHECK(in_it->second.top()->piece_status() == cur_pst); + } + + auto out_diff_it = pid2out_diff_regsts_.find(cur_pid); + if (out_diff_it == pid2out_diff_regsts_.end()) { return false; } + if (cur_pst.IsLastCol()) { + if (out_diff_it->second.back()->piece_status() != cur_pst) { return false; } + } else { + CHECK(out_diff_it->second.back()->piece_status() == cur_pst); + } + + auto act_it = pid2activation_regsts_.find(cur_pid); + if (act_it == pid2activation_regsts_.end()) { return false; } + if (cur_pst.IsLastCol()) { + if (act_it->second.top()->piece_status() != cur_pst) { return false; } + } else { + CHECK(act_it->second.top()->piece_status() == cur_pst); + } + + return true; +} + +void BasicRnnBackwardCompActor::FillReadableWithIn_OutDiff_Model_Activation( + Regst* out_regst) { + int64_t cur_pid = out_regst->piece_status().piece_id(); + int64_t cur_model_vid = out_regst->model_version_id(); + readable_regsts_.emplace(in_regst_desc_id_, pid2in_regsts_.at(cur_pid).top()); + readable_regsts_.emplace(out_diff_regst_desc_id_, + pid2out_diff_regsts_.at(cur_pid).back()); + readable_regsts_.emplace(model_regst_desc_id_, + model_vid2model_regst_.at(cur_model_vid)); + readable_regsts_.emplace(activation_regst_desc_id_, + pid2activation_regsts_.at(cur_pid).top()); +} + +bool BasicRnnBackwardCompActor::IsReadReady() { + if (pid2in_regsts_.empty() || pid2out_regsts_.empty() + || pid2out_diff_regsts_.empty() || model_vid2model_regst_.empty() + || pid2activation_regsts_.empty()) { + return false; + } + for (const auto& kv : pid2out_regsts_) { + Regst* out_regst = kv.second.back(); + const PieceStatus& cur_pst = out_regst->piece_status(); + int64_t cur_pid = cur_pst.piece_id(); + + if (!CheckModel_In_OutDiff_Activation(out_regst)) { continue; } + + readable_regsts_.clear(); + if (cur_pst.col_id() == 0) { + auto init_hid_it = pid2init_hid_regsts_.find(cur_pid); + if (init_hid_it == pid2init_hid_regsts_.end()) { continue; } + readable_regsts_.emplace(initial_hidden_regst_desc_id_, + init_hid_it->second); + } else { + readable_regsts_.emplace(out_regst_desc_id_, + *(pid2out_regsts_.at(cur_pid).end() - 2)); + } + if (!cur_pst.IsLastCol()) { + auto rec_acc_it = pid2rec_acc_diff_regsts_.find(cur_pid); + if (rec_acc_it == pid2rec_acc_diff_regsts_.end()) { continue; } + CHECK(rec_acc_it->second->piece_status().IsNextColOf( + out_regst->piece_status())); + readable_regsts_.emplace(rec_acc_diff_regst_desc_id_, rec_acc_it->second); + } else { + CHECK_EQ(kv.second.size(), pid2out_regsts_.at(cur_pid).size()); + CHECK_EQ(kv.second.size(), pid2activation_regsts_.at(cur_pid).size()); + } + FillReadableWithIn_OutDiff_Model_Activation(out_regst); + return true; + } + return false; +} + +bool BasicRnnBackwardCompActor::IsReadAlwaysUnReadyFromNow() { + return is_out_diff_eord_ && pid2out_diff_regsts_.empty(); +} + +void BasicRnnBackwardCompActor::RelModelByJudgingStatus(int64_t model_vid) { + if (model_vid2status_.at(model_vid)) { + AsyncSendRegstMsgToProducer(model_vid2model_regst_.at(model_vid)); + model_vid2model_regst_.erase(model_vid); + model_vid2status_.erase(model_vid); + } +} + +void BasicRnnBackwardCompActor::UpdtModelStatusAfterAct() { + Regst* out_diff_regst = readable_regsts_.at(out_diff_regst_desc_id_); + const PieceStatus& cur_pst = out_diff_regst->piece_status(); + int64_t cur_col_id = cur_pst.col_id(); + Regst* model_regst = readable_regsts_.at(model_regst_desc_id_); + int64_t cur_model_vid = model_regst->model_version_id(); + + if (cur_col_id == 0) { + model_vid2cnt_.at(cur_model_vid) -= 1; + if (model_vid2cnt_.at(cur_model_vid) == 0) { + model_vid2cnt_.erase(cur_model_vid); + RelModelByJudgingStatus(cur_model_vid); + } + } +} + +void BasicRnnBackwardCompActor::Act() { + AsyncLaunchKernel( + GenDefaultKernelCtx(), + [this](int64_t regst_desc_id) -> Regst* { return nullptr; }); + AsyncSendRegstMsgToConsumer([](Regst* regst) { + regst->set_is_forward(false); + return true; + }); + + Regst* out_diff_regst = readable_regsts_.at(out_diff_regst_desc_id_); + const PieceStatus& cur_pst = out_diff_regst->piece_status(); + int64_t cur_pid = cur_pst.piece_id(); + int64_t cur_col_id = cur_pst.col_id(); + Regst* model_regst = readable_regsts_.at(model_regst_desc_id_); + + UpdtModelStatusAfterAct(); + +#define ERASE_ELES_IN_HASHMAP_WHEN_COL0(hash_map) \ + do { \ + if (cur_col_id == 0) { \ + CHECK(hash_map.at(cur_pid).empty()); \ + hash_map.erase(cur_pid); \ + } \ + } while (0) + + // update out_regst + // the out_regst inserted to readable_regsts_ is not back(), but 'back()-1' + CHECK(pid2out_regsts_.at(cur_pid).back()->piece_status() == cur_pst); + AsyncSendRegstMsgToProducer(pid2out_regsts_.at(cur_pid).back()); + pid2out_regsts_.at(cur_pid).pop_back(); + ERASE_ELES_IN_HASHMAP_WHEN_COL0(pid2out_regsts_); + + for (auto& kv : readable_regsts_) { + if (kv.first == model_regst_desc_id_) { continue; } + if (kv.first == out_regst_desc_id_) { continue; } + AsyncSendRegstMsgToProducer(kv.second); + + if (kv.first == in_regst_desc_id_) { + pid2in_regsts_.at(cur_pid).pop(); + ERASE_ELES_IN_HASHMAP_WHEN_COL0(pid2in_regsts_); + } else if (kv.first == out_diff_regst_desc_id_) { + pid2out_diff_regsts_.at(cur_pid).pop_back(); + if (pid2out_diff_regsts_.at(cur_pid).empty()) { + pid2out_diff_regsts_.erase(cur_pid); + } + } else if (kv.first == initial_hidden_regst_desc_id_) { + CHECK_EQ(0, cur_col_id); + pid2init_hid_regsts_.erase(cur_pid); + } else if (kv.first == rec_acc_diff_regst_desc_id_) { + CHECK(!cur_pst.IsLastCol()); + pid2rec_acc_diff_regsts_.erase(cur_pid); + } else if (kv.first == activation_regst_desc_id_) { + pid2activation_regsts_.at(cur_pid).pop(); + ERASE_ELES_IN_HASHMAP_WHEN_COL0(pid2activation_regsts_); +#undef ERASE_ELES_IN_HASHMAP_WHEN_COL0 + } else { + UNEXPECTED_RUN(); + } + } +} + +void BasicRnnBackwardCompActor::AsyncReturnAllReadableRegst() { + CHECK(pid2in_regsts_.empty()); + CHECK(pid2out_regsts_.empty()); + CHECK(pid2out_diff_regsts_.empty()); + CHECK(pid2init_hid_regsts_.empty()); + CHECK(pid2rec_acc_diff_regsts_.empty()); + CHECK(pid2activation_regsts_.empty()); + CHECK(model_vid2cnt_.empty()); + CHECK(model_vid2status_.empty()); + for (auto& kv : model_vid2model_regst_) { + AsyncSendRegstMsgToProducer(kv.second); + } + model_vid2model_regst_.clear(); +} + +REGISTER_ACTOR(TaskType::kBasicRnnBackward, BasicRnnBackwardCompActor); + +} // namespace oneflow diff --git a/oneflow/core/actor/basic_rnn_backward_compute_actor.h b/oneflow/core/actor/basic_rnn_backward_compute_actor.h new file mode 100644 index 0000000000..7c2d54a6fb --- /dev/null +++ b/oneflow/core/actor/basic_rnn_backward_compute_actor.h @@ -0,0 +1,63 @@ +#ifndef ONEFLOW_CORE_ACTOR_BASIC_RNN_BACKWARD_COMPUTE_ACTOR_H_ +#define ONEFLOW_CORE_ACTOR_BASIC_RNN_BACKWARD_COMPUTE_ACTOR_H_ + +#include +#include "oneflow/core/actor/compute_actor.h" + +namespace oneflow { + +class BasicRnnBackwardCompActor final : public CompActor { + public: + OF_DISALLOW_COPY_AND_MOVE(BasicRnnBackwardCompActor); + BasicRnnBackwardCompActor() = default; + ~BasicRnnBackwardCompActor() = default; + + void VirtualCompActorInit(const TaskProto&) override; + + private: + int HandlerNormal(const ActorMsg&) override; + bool IsReadReady() override; + bool IsReadAlwaysUnReadyFromNow() override; + void AsyncReturnAllReadableRegst() override; + void Act() override; + + bool CheckModel_In_OutDiff_Activation(Regst*) const; + void FillReadableWithIn_OutDiff_Model_Activation(Regst*); + void UpdtModelStatusAfterAct(); + void RelModelByJudgingStatus(int64_t); // Rel for Release + + int64_t in_regst_desc_id_; + HashMap> pid2in_regsts_; + + int64_t out_regst_desc_id_; + HashMap> pid2out_regsts_; + + int64_t initial_hidden_regst_desc_id_; + HashMap pid2init_hid_regsts_; + + int64_t out_diff_regst_desc_id_; + // regst in deque is ascending by col_id + HashMap> pid2out_diff_regsts_; + bool is_insert_to_back_; + + int64_t rec_acc_diff_regst_desc_id_; // recurrent accumulate diff + HashMap pid2rec_acc_diff_regsts_; + + int64_t model_regst_desc_id_; + HashMap model_vid2model_regst_; + HashMap model_vid2cnt_; + // the only way to release a model regst is through model_vid2status_ + // except the last several unused model regsts + // + std::map model_vid2status_; + + int64_t activation_regst_desc_id_; + HashMap> pid2activation_regsts_; + + bool is_out_diff_eord_ = false; + HashMap readable_regsts_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_ACTOR_BASIC_RNN_BACKWARD_COMPUTE_ACTOR_H_ diff --git a/oneflow/core/actor/basic_rnn_forward_compute_actor.cpp b/oneflow/core/actor/basic_rnn_forward_compute_actor.cpp index 24deb3b8ed..dba95ed0af 100644 --- a/oneflow/core/actor/basic_rnn_forward_compute_actor.cpp +++ b/oneflow/core/actor/basic_rnn_forward_compute_actor.cpp @@ -137,17 +137,18 @@ void BasicRnnForwardCompActor::UpdtInAndModelStates() { int64_t model_vid = model_regst->model_version_id(); model_regst2cnt_.at(model_regst) -= 1; + int64_t last_pid = GetLastPieceIdForModelVersionId(model_vid); if (model_regst2cnt_.at(model_regst) == 0) { model_regst2cnt_.erase(model_regst); if (latest_model_regst_ != model_regst - || cur_pid == GetLastPieceIdForModelVersionId(model_vid) - || models_to_be_released_.find(model_regst) + || cur_pid == last_pid + || models_to_be_released_.find(model_regst) != models_to_be_released_.end()) { AsyncSendRegstMsgToProducer(model_regst); if (model_regst == latest_model_regst_) { latest_model_regst_ = nullptr; } - if (models_to_be_released_.find(model_regst) + if (models_to_be_released_.find(model_regst) != models_to_be_released_.end()) { models_to_be_released_.erase(model_regst); } diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index 9057101075..14d1ffb4b1 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -22,6 +22,7 @@ enum TaskType { kPrint = 13; kRnnSource = 14; kBasicRnnForward = 15; + kBasicRnnBackward = 16; }; message TaskProto { -- GitLab