提交 a260b163 编写于 作者: J Jingwu Chen 提交者: Will Zhang

boxing for both cnn & rnn (#513)

* boxing for both cnn & rnn

* add ColIdOrder

* order uncertain

* fix boxing

* last -- max

* boxing kernel context

* fix int32

* new delete previous_pid_cid_

* nullptr after delete

* fix bug

* judge when max_col_num > 1
上级 9094e8b9
...@@ -4,20 +4,50 @@ ...@@ -4,20 +4,50 @@
namespace oneflow { namespace oneflow {
void BoxingActor::VirtualActorInit(const TaskProto& task_proto) { void BoxingActor::VirtualActorInit(const TaskProto& task_proto) {
is_eord_ = false;
for (const auto& pair : task_proto.consumed_regst_desc_id()) { for (const auto& pair : task_proto.consumed_regst_desc_id()) {
readable_regst_[pair.second] = {}; readable_regst_[pair.second] = {};
} }
previous_pid_cid_ = new HashMap<int64_t, std::pair<int64_t, int32_t>>;
readable_regst_cnt_ = 0; readable_regst_cnt_ = 0;
col_id_order_ = ColIdOrder::kUnCertain;
is_eord_ = false;
OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal); OF_SET_MSG_HANDLER(&BoxingActor::HandlerNormal);
} }
// Attempts to decide col_id_order_ from the regst that has just arrived.
//
// previous_pid_cid_ remembers, per regst_desc_id, the (piece_id, col_id) of
// the most recently observed regst. A decision is only made when the incoming
// regst has the same piece_id as the remembered one for its regst_desc_id:
//   - col_id == remembered col_id + 1  -> ColIdOrder::kAscending
//   - col_id == remembered col_id - 1  -> ColIdOrder::kDescending
//   - anything else trips CHECK_EQ.
// Once the order is set, the bookkeeping map is deleted and the pointer is
// reset to nullptr, so this function must not be called again afterwards.
void BoxingActor::TrySetColIdOrder(const Regst* cur_regst) {
  const int64_t desc_id = cur_regst->regst_desc_id();
  const int64_t piece_id = cur_regst->piece_id();
  const int32_t col_id = cur_regst->col_id();

  auto seen_it = previous_pid_cid_->find(desc_id);
  if (seen_it == previous_pid_cid_->end()) {
    // First regst seen for this regst_desc_id: just record it.
    (*previous_pid_cid_)[desc_id] = std::make_pair(piece_id, col_id);
    return;
  }

  auto& last_seen = seen_it->second;
  if (last_seen.first != piece_id) {
    // Different piece_id than last time: overwrite the record and wait for
    // another regst with the same piece_id before comparing col_ids.
    last_seen = std::make_pair(piece_id, col_id);
    return;
  }

  // Same piece_id as the remembered regst: adjacent col_ids reveal the order.
  const bool is_ascending = (col_id == last_seen.second + 1);
  if (!is_ascending) { CHECK_EQ(col_id, last_seen.second - 1); }
  col_id_order_ =
      is_ascending ? ColIdOrder::kAscending : ColIdOrder::kDescending;

  // The history is no longer needed once the order is known.
  delete previous_pid_cid_;
  previous_pid_cid_ = nullptr;
}
int BoxingActor::HandlerNormal(const ActorMsg& msg) { int BoxingActor::HandlerNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kEordMsg) { if (msg.msg_type() == ActorMsgType::kEordMsg) {
is_eord_ = true; is_eord_ = true;
DecreaseRemainingEordCnt(); DecreaseRemainingEordCnt();
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) { } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) { if (TryUpdtStateAsProducedRegst(msg.regst()) != 0) {
if (msg.regst()->packed_blob()->max_col_num() > 1
&& col_id_order_ == ColIdOrder::kUnCertain) {
TrySetColIdOrder(msg.regst());
}
std::queue<Regst*>& rq = readable_regst_.at(msg.regst()->regst_desc_id()); std::queue<Regst*>& rq = readable_regst_.at(msg.regst()->regst_desc_id());
if (rq.empty()) { readable_regst_cnt_ += 1; } if (rq.empty()) { readable_regst_cnt_ += 1; }
rq.push(msg.regst()); rq.push(msg.regst());
...@@ -30,7 +60,6 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) { ...@@ -30,7 +60,6 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
} }
void BoxingActor::Act() { void BoxingActor::Act() {
int64_t piece_id = readable_regst_.begin()->second.front()->piece_id();
AsyncLaunchKernel(GenDefaultKernelCtx(), AsyncLaunchKernel(GenDefaultKernelCtx(),
[this](int64_t regst_desc_id) -> Regst* { [this](int64_t regst_desc_id) -> Regst* {
Regst* regst = GetCurWriteableRegst(regst_desc_id); Regst* regst = GetCurWriteableRegst(regst_desc_id);
...@@ -41,10 +70,25 @@ void BoxingActor::Act() { ...@@ -41,10 +70,25 @@ void BoxingActor::Act() {
} }
}); });
AsyncSendRegstMsgToConsumer([&](Regst* regst) { AsyncSendRegstMsgToConsumer([&](Regst* regst) {
regst->set_piece_id(piece_id); regst->set_piece_id(regst->piece_id());
return true; return regst->col_id() <= regst->max_col_id();
}); });
int32_t cur_max_cid = 0;
int32_t cur_max_maxcid = 0;
for (const auto& pair : readable_regst_) {
cur_max_cid = std::max(cur_max_cid, pair.second.front()->col_id());
cur_max_maxcid =
std::max(cur_max_maxcid, pair.second.front()->max_col_id());
}
for (auto& pair : readable_regst_) { for (auto& pair : readable_regst_) {
if (col_id_order_ == ColIdOrder::kAscending) {
if (pair.second.front()->IsMaxCol() && cur_max_cid < cur_max_maxcid) {
continue;
}
} else if (col_id_order_ == ColIdOrder::kDescending) {
if (pair.second.front()->col_id() < cur_max_cid) { continue; }
} else { // do nothing
}
AsyncSendRegstMsgToProducer(pair.second.front()); AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop(); pair.second.pop();
if (pair.second.empty()) { readable_regst_cnt_ -= 1; } if (pair.second.empty()) { readable_regst_cnt_ -= 1; }
......
...@@ -21,9 +21,15 @@ class BoxingActor final : public Actor { ...@@ -21,9 +21,15 @@ class BoxingActor final : public Actor {
bool IsReadAlwaysUnReadyFromNow() override; bool IsReadAlwaysUnReadyFromNow() override;
void AsyncReturnAllReadableRegst() override; void AsyncReturnAllReadableRegst() override;
bool is_eord_; void TrySetColIdOrder(const Regst*);
// <regst_desc_id, regst*>
HashMap<int64_t, std::queue<Regst*>> readable_regst_; HashMap<int64_t, std::queue<Regst*>> readable_regst_;
int64_t readable_regst_cnt_; // <regst_desc_id, <pid, cid>>
HashMap<int64_t, std::pair<int64_t, int32_t>>* previous_pid_cid_;
int8_t readable_regst_cnt_;
ColIdOrder col_id_order_;
bool is_eord_;
}; };
} // namespace oneflow } // namespace oneflow
......
...@@ -54,7 +54,6 @@ void RandomUniformInitializer( ...@@ -54,7 +54,6 @@ void RandomUniformInitializer(
blob->shape().elem_cnt(), static_cast<T>(initializer_conf.min()), blob->shape().elem_cnt(), static_cast<T>(initializer_conf.min()),
static_cast<T>(initializer_conf.max()), random_seed, blob->mut_dptr<T>()); static_cast<T>(initializer_conf.max()), random_seed, blob->mut_dptr<T>());
} }
template<typename T> template<typename T>
void RandomNormalInitializer( void RandomNormalInitializer(
const RandomNormalInitializerConf& initializer_conf, uint32_t random_seed, const RandomNormalInitializerConf& initializer_conf, uint32_t random_seed,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册