未验证 提交 3654d164 编写于 作者: J Jinhui Yuan 提交者: GitHub

fix wrong ActNum of ctrl regst produced by AccCompActor (#1136)

上级 5be84c50
......@@ -25,7 +25,9 @@ void AccumulateCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt,
next_piece_id_ = 0;
}
int64_t AccumulateCompActor::ActNumForEachOutput() const { return max_acc_cnt_; }
int64_t AccumulateCompActor::ActNumForEachOutput(int64_t regst_desc_id) const {
return regst_desc_id == GetSoleProducedDataRegstDescId() ? max_acc_cnt_ : 1;
}
void AccumulateCompActor::Act() {
Regst* in_regst = GetNaiveSoleCurReadable();
......
......@@ -13,7 +13,7 @@ class AccumulateCompActor : public CompActor {
protected:
void Init(const TaskProto&, int32_t max_acc_cnt, ColIdOrder order);
int64_t ActNumForEachOutput() const override;
int64_t ActNumForEachOutput(int64_t regst_desc_id) const override;
private:
void Act() override;
......
......@@ -417,6 +417,11 @@ Regst* Actor::GetSoleProducedRegst(int64_t regst_desc_id) {
return it->second.front().get();
}
int64_t Actor::GetSoleProducedDataRegstDescId() const {
CHECK_EQ(produced_data_regsts_.size(), 1);
return produced_data_regsts_.begin()->first;
}
bool Actor::IsReadReady() {
return naive_readable_data_regst_.size() == naive_readable_data_regst_cnt_
&& IsCustomizedReadReady();
......@@ -443,7 +448,7 @@ int Actor::ProcessWriteableCtrlRegstMsg(const ActorMsg& msg) {
if (expected_act_id >= 0 && CheckOutputActId(regst->regst_desc_id())) {
CHECK_EQ(regst->act_id(), expected_act_id);
}
expected_act_id = regst->act_id() + ActNumForEachOutput();
expected_act_id = regst->act_id() + ActNumForEachOutput(regst->regst_desc_id());
writeable_it->second.push_back(regst);
return 0;
}
......@@ -524,7 +529,7 @@ int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
if (expected_act_id >= 0 && CheckOutputActId(regst->regst_desc_id())) {
CHECK_EQ(regst->act_id(), expected_act_id);
}
expected_act_id = regst->act_id() + ActNumForEachOutput();
expected_act_id = regst->act_id() + ActNumForEachOutput(regst->regst_desc_id());
writeable_it->second.push_back(regst);
return 0;
}
......
......@@ -85,7 +85,7 @@ class Actor {
virtual bool IsCustomizedReadAlwaysUnReadyFromNow() { return false; }
bool IsWriteReady();
virtual void AsyncReturnAllCustomizedReadableRegst() {}
virtual int64_t ActNumForEachOutput() const { return 1; }
virtual int64_t ActNumForEachOutput(int64_t regst_desc_id) const { return 1; }
virtual bool CheckOutputActId(int64_t regst_desc_id) const {
return true; // TODO(jiyuan): figure out the ActNumForEachOutput of the model regsts to MdSave
// area
......@@ -117,6 +117,7 @@ class Actor {
Regst* GetNaiveSoleCurReadable();
Regst* GetNaiveFirstCurReadable();
Regst* GetSoleProducedRegst(int64_t regst_desc_id);
int64_t GetSoleProducedDataRegstDescId() const;
void DecreaseActualWriteableProducedDataRegstDescNum(int64_t amount) {
actual_writeable_produced_data_regst_desc_num_ -= amount;
......
......@@ -29,7 +29,9 @@ void InputWiseCompActor::Init(const TaskProto& task_proto) {
OF_SET_MSG_HANDLER(&InputWiseCompActor::HandlerNormal);
}
int64_t InputWiseCompActor::ActNumForEachOutput() const { return regst_desc_id2in_bn_id_.size(); }
int64_t InputWiseCompActor::ActNumForEachOutput(int64_t regst_desc_id) const {
return regst_desc_id2in_bn_id_.size();
}
void InputWiseCompActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {
Regst* regst = msg.regst();
......
......@@ -17,7 +17,7 @@ class InputWiseCompActor : public CompActor {
int64_t processed_regst_desc_id_cnt() const { return processed_regst_desc_id_cnt_; }
int64_t RegstDescNum() const { return readable_regsts_.size(); }
int64_t InBnId4RegstDescId(int64_t id) const { return regst_desc_id2in_bn_id_.at(id); }
int64_t ActNumForEachOutput() const override;
int64_t ActNumForEachOutput(int64_t regst_desc_id) const override;
bool EnableInplace() const {
return GetDeviceType() == DeviceType::kGPU && Global<JobDesc>::Get()->enable_mem_sharing();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册