提交 a260b163 编写于 作者: J Jingwu Chen 提交者: Will Zhang

boxing for both cnn & rnn (#513)

* boxing for both cnn & rnn

* add ColIdOrder

* order uncertain

* fix boxing

* last -- max

* boxing kernel context

* fix int32

* new delete previous_pid_cid_

* nullptr after delete

* fix bug

* judge when max_col_num > 1
上级 9094e8b9
......@@ -4,20 +4,50 @@
namespace oneflow {
// Resets the actor's bookkeeping for the given task and installs
// HandlerNormal as the message handler.
//
// task_proto: task description; every entry of consumed_regst_desc_id()
//             gets its own (initially empty) queue in readable_regst_.
void BoxingActor::VirtualActorInit(const TaskProto& task_proto) {
  // readable_regst_ is later looked up by regst_desc_id in HandlerNormal, so
  // an empty queue is created up front for each consumed entry.
  for (const auto& pair : task_proto.consumed_regst_desc_id()) {
    readable_regst_[pair.second] = {};
  }
  // Scratch table of the last <piece_id, col_id> seen per regst_desc_id.
  // TrySetColIdOrder deletes it and nulls the pointer once the column order
  // has been decided.
  previous_pid_cid_ = new HashMap<int64_t, std::pair<int64_t, int32_t>>;
  readable_regst_cnt_ = 0;
  col_id_order_ = ColIdOrder::kUnCertain;
  // Assigned exactly once here; the earlier duplicate assignment at the top
  // of the function was redundant.
  is_eord_ = false;
  OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal);
}
// Attempts to decide col_id_order_ from two consecutive registers that share
// a regst_desc_id and a piece_id.
//
// cur_regst: the register that has just arrived.
//
// - First register seen for its regst_desc_id: remember <piece_id, col_id>.
// - Same regst_desc_id but a different piece_id: overwrite the remembered
//   pair; no decision is made.
// - Same regst_desc_id and piece_id: a col_id one larger than the remembered
//   one selects kAscending; otherwise it is CHECKed to be exactly one smaller
//   and kDescending is selected. The scratch table is then freed and the
//   pointer set to nullptr.
void BoxingActor::TrySetColIdOrder(const Regst* cur_regst) {
  const int64_t desc_id = cur_regst->regst_desc_id();
  const int64_t cur_pid = cur_regst->piece_id();
  const int32_t cur_cid = cur_regst->col_id();

  auto seen_it = previous_pid_cid_->find(desc_id);
  if (seen_it == previous_pid_cid_->end()) {
    // Nothing recorded yet for this regst_desc_id.
    (*previous_pid_cid_)[desc_id] = std::make_pair(cur_pid, cur_cid);
    return;
  }

  auto& pre_pid_cid = seen_it->second;
  if (pre_pid_cid.first != cur_pid) {
    // A different piece: just refresh the remembered pair.
    pre_pid_cid = std::make_pair(cur_pid, cur_cid);
    return;
  }

  if (cur_cid != pre_pid_cid.second + 1) {
    CHECK_EQ(cur_cid, pre_pid_cid.second - 1);
    col_id_order_ = ColIdOrder::kDescending;
  } else {
    col_id_order_ = ColIdOrder::kAscending;
  }

  // The order is now fixed; the scratch table is no longer needed.
  delete previous_pid_cid_;
  previous_pid_cid_ = nullptr;
}
int BoxingActor::HandlerNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kEordMsg) {
is_eord_ = true;
DecreaseRemainingEordCnt();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) {
if (msg.regst()->packed_blob()->max_col_num() > 1
&& col_id_order_ == ColIdOrder::kUnCertain) {
TrySetColIdOrder(msg.regst());
}
std::queue<Regst*>& rq = readable_regst_.at(msg.regst()->regst_desc_id());
if (rq.empty()) { readable_regst_cnt_ += 1; }
rq.push(msg.regst());
......@@ -30,7 +60,6 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
}
void BoxingActor::Act() {
int64_t piece_id = readable_regst_.begin()->second.front()->piece_id();
AsyncLaunchKernel(GenDefaultKernelCtx(),
[this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
......@@ -41,10 +70,25 @@ void BoxingActor::Act() {
}
});
AsyncSendRegstMsgToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id);
return true;
regst->set_piece_id(regst->piece_id());
return regst->col_id() <= regst->max_col_id();
});
int32_t cur_max_cid = 0;
int32_t cur_max_maxcid = 0;
for (const auto& pair : readable_regst_) {
cur_max_cid = std::max(cur_max_cid, pair.second.front()->col_id());
cur_max_maxcid =
std::max(cur_max_maxcid, pair.second.front()->max_col_id());
}
for (auto& pair : readable_regst_) {
if (col_id_order_ == ColIdOrder::kAscending) {
if (pair.second.front()->IsMaxCol() && cur_max_cid < cur_max_maxcid) {
continue;
}
} else if (col_id_order_ == ColIdOrder::kDescending) {
if (pair.second.front()->col_id() < cur_max_cid) { continue; }
} else { // do nothing
}
AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop();
if (pair.second.empty()) { readable_regst_cnt_ -= 1; }
......
......@@ -21,9 +21,15 @@ class BoxingActor final : public Actor {
bool IsReadAlwaysUnReadyFromNow() override;
void AsyncReturnAllReadableRegst() override;
bool is_eord_;
void TrySetColIdOrder(const Regst*);
// <regst_desc_id, queue of regst*>
HashMap<int64_t, std::queue<Regst*>> readable_regst_;
int64_t readable_regst_cnt_;
// <regst_desc_id, <pid, cid>>
HashMap<int64_t, std::pair<int64_t, int32_t>>* previous_pid_cid_;
int8_t readable_regst_cnt_;
ColIdOrder col_id_order_;
bool is_eord_;
};
} // namespace oneflow
......
......@@ -54,7 +54,6 @@ void RandomUniformInitializer(
blob->shape().elem_cnt(), static_cast<T>(initializer_conf.min()),
static_cast<T>(initializer_conf.max()), random_seed, blob->mut_dptr<T>());
}
template<typename T>
void RandomNormalInitializer(
const RandomNormalInitializerConf& initializer_conf, uint32_t random_seed,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册