提交 34e53189 编写于 作者: Y Yi Zhu 提交者: Will Zhang

add min_register_num and max_register_num (#344)

* add min_register_num and max_register_num

* fix bug

* update register_num for nodes

* update loss register num

* updt register num interval

* fix interfaces and bug

* fix bug

* fix bug

* fix bug

* fix bug
上级 dfef4ad8
......@@ -45,10 +45,10 @@ void BoxingTaskNode::EnrollAllRegstAndBindRelatedEdge() {
}
for (TaskEdge* edge : out_edges()) {
std::string name = "boxing_out_" + edge->edge_id_str();
auto regst_desc = NewProducedRegstDesc(name);
auto regst_desc = NewProducedRegstDesc(name, 1, kMaxRegisterNum);
BindProducedRegstAndOutEdge(regst_desc, edge);
}
NewProducedRegstDesc("middle");
NewProducedRegstDesc("middle", 1);
}
void BoxingTaskNode::FwInitChain2SortedEdgesMaps(
......
......@@ -5,7 +5,7 @@
namespace oneflow {
void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*) {
auto out_regst = NewProducedRegstDesc("copy_out");
auto out_regst = NewProducedRegstDesc("copy_out", 1, kMaxRegisterNum);
BindProducedRegstAndOutEdge(out_regst, SoleOutEdge());
std::shared_ptr<RegstDesc> in_regst = GetRelatedRegst(SoleInEdge());
ConsumeRegstDesc("copy_in", in_regst);
......
......@@ -9,14 +9,14 @@ void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph().UpdateSourceAndSink();
// Enroll Produced Regsts
if (!out_edges().empty()) {
auto out_regst = NewProducedRegstDesc("out");
auto out_regst = NewProducedRegstDesc("out", 1, kMaxRegisterNum);
BindProducedRegstAndOutEdge(out_regst, SoleOutEdge());
}
NewProducedRegstDesc("activation");
NewProducedRegstDesc("data_tmp");
NewProducedRegstDesc("model_tmp");
NewProducedRegstDesc("model");
NewProducedRegstDesc("loss");
NewProducedRegstDesc("activation", 1, kMaxRegisterNum);
NewProducedRegstDesc("data_tmp", 1, kMaxRegisterNum);
NewProducedRegstDesc("model_tmp", 1);
NewProducedRegstDesc("model", 3, kMaxRegisterNum);
NewProducedRegstDesc("loss", 1, kMaxRegisterNum);
// Enroll Lbn
FwSetExecNodeFromInRegst(extern_in_lbn2consumer);
FwEnrollLbn2OutRegst(lbn2producer);
......@@ -153,12 +153,12 @@ void DataCompTaskNode::FwEnrollLbn2ModelAndTmpRegsts() {
void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
BpBuildExecGraph();
// New produced registers
auto in_diff_regst = NewProducedRegstDesc("in_diff");
auto in_diff_regst = NewProducedRegstDesc("in_diff", 1, kMaxRegisterNum);
if (!out_edges().empty()) {
BindProducedRegstAndOutEdge(in_diff_regst, SoleOutEdge());
}
NewProducedRegstDesc("model_diff");
NewProducedRegstDesc("activation_diff");
NewProducedRegstDesc("model_diff", 1, kMaxRegisterNum);
NewProducedRegstDesc("activation_diff", 1);
// Subscribe
ConsumeRegstDesc("activation",
GetFwNode()->GetProducedRegstDesc("activation"));
......
......@@ -10,7 +10,7 @@ void LossAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
BindProducedRegstAndOutEdge(loss_regst, SoleOutEdge());
return;
}
NewProducedRegstDesc("loss_acc");
NewProducedRegstDesc("loss_acc", 1, kMaxRegisterNum);
auto loss_regst = GetRelatedRegst(SoleInEdge());
auto loss_acc_regst = GetProducedRegstDesc("loss_acc");
ExecNode* exec_node = mut_exec_gph().NewNode();
......
......@@ -16,7 +16,7 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
return;
}
// comp task node
NewProducedRegstDesc("model_diff_acc");
NewProducedRegstDesc("model_diff_acc", 1, kMaxRegisterNum);
auto model_diff_acc_regst = GetProducedRegstDesc("model_diff_acc");
ExecNode* exec_node = mut_exec_gph().NewNode();
......
......@@ -25,7 +25,7 @@ void MdUpdtCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
ConsumeRegstDesc(ibn, model_diff_acc_regst);
}
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(), model_regst);
auto data_tmp_regst = NewProducedRegstDesc("data_tmp");
auto data_tmp_regst = NewProducedRegstDesc("data_tmp", 1);
for (const std::string& dtbn : exec_node->op()->data_tmp_bns()) {
const std::string& lbn = exec_node->op()->Lbn4BnInOp(dtbn);
data_tmp_regst->EnrollLbn(lbn);
......
......@@ -116,10 +116,14 @@ void TaskNode::BindProducedRegstAndOutEdge(std::weak_ptr<RegstDesc> regst,
}
std::shared_ptr<RegstDesc> TaskNode::NewProducedRegstDesc(
const std::string& regst_desc_name) {
const std::string& regst_desc_name, int32_t min_register_num,
int32_t max_register_num) {
auto regst_desc = std::make_shared<RegstDesc>();
regst_desc->SetProducer(this);
regst_desc->set_regst_desc_id(IDMgr::Singleton()->NewRegstDescId());
CHECK_LE(min_register_num, max_register_num);
regst_desc->set_min_register_num(min_register_num);
regst_desc->set_max_register_num(max_register_num);
CHECK(produced_regst_descs_.emplace(regst_desc_name, regst_desc).second);
return regst_desc;
}
......
......@@ -94,7 +94,14 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
void BindProducedRegstAndOutEdge(std::weak_ptr<RegstDesc>, const TaskEdge*);
std::shared_ptr<RegstDesc> NewProducedRegstDesc(
const std::string& regst_desc_name);
const std::string& regst_desc_name, int32_t min_register_num,
int32_t max_register_num);
std::shared_ptr<RegstDesc> NewProducedRegstDesc(
const std::string& regst_desc_name, int32_t register_num) {
return NewProducedRegstDesc(regst_desc_name, register_num, register_num);
}
void ConsumeRegstDesc(const std::string& regst_desc_name,
std::shared_ptr<RegstDesc> regst_desc);
......
......@@ -114,7 +114,9 @@ void RegstDesc::ToProto(RegstDescProto* ret) const {
pair.second->ToProto(&(pb_pair.second));
ret->mutable_lbn2blob_desc()->insert(pb_pair);
}
ret->set_register_num(register_num_);
ret->set_register_num(min_register_num_);
ret->set_min_register_num(min_register_num_);
ret->set_max_register_num(max_register_num_);
*(ret->mutable_mem_case()) = InferMemCase();
}
......
......@@ -6,6 +6,8 @@
namespace oneflow {
const int32_t kMaxRegisterNum = std::numeric_limits<int32_t>::max();
class TaskNode;
class RegstDesc final {
......@@ -17,6 +19,13 @@ class RegstDesc final {
// regst_desc_id
int64_t regst_desc_id() const { return regst_desc_id_; }
void set_regst_desc_id(int64_t val) { regst_desc_id_ = val; }
//
int32_t min_register_num() const { return min_register_num_; }
void set_min_register_num(int32_t val) { min_register_num_ = val; }
int32_t max_register_num() const { return max_register_num_; }
void set_max_register_num(int32_t val) { max_register_num_ = val; }
// Producer
const TaskNode* GetProducer() const { return producer_; }
void SetProducer(const TaskNode* task_node) { producer_ = task_node; }
......@@ -44,6 +53,8 @@ class RegstDesc final {
int64_t regst_desc_id_;
const TaskNode* producer_;
HashSet<const TaskNode*> consumers_;
int32_t min_register_num_;
int32_t max_register_num_;
HashMap<std::string, std::unique_ptr<BlobDesc>> lbn2blob_desc_;
int64_t register_num_;
......
......@@ -9,6 +9,8 @@ message RegstDescProto {
int64 producer_task_id = 2;
repeated int64 consumer_task_id = 3;
map<string, BlobDescProto> lbn2blob_desc = 4;
int64 register_num = 5;
MemoryCase mem_case = 6;
int32 min_register_num = 5;
int32 max_register_num = 6;
int32 register_num = 7;
MemoryCase mem_case = 8;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册