Commit f74b651b authored by chengtbf, committed by Will Zhang

refine fw(recurrent/normal) comp task node (#511)

* fw task node for recurrent out edge

* recurrent task node 2 out regst

* set min regst num of rec_ht_regst

* todo is recurrent out edge

* refine name

* implement is recurrent out edge for comp task node

* reduce code for last impl

* fix for subclass func

* change bind out regst in fw task node by SunPeiWen

* fix for review

* unexpected_run for virtual func

* fix bug caused by zyeric

* change interface: delete isrecurrentoutedge, add get succ/pred chain node on edge

* rec_ht

* fix name

* fix name

* refine fw(recurrent/normal) compute task node

* fix bug of bind in regst

* delete node

* fix for review

* abstract new func for succ/pred chain node on edge

* fix max regst in model parallel

* fix back


Former-commit-id: 5970b3d3
Parent 1b90b3a7
@@ -3,11 +3,36 @@
 namespace oneflow {
 
+namespace {
+
+const ChainNode* ChainNodeOnEdge(
+    TaskEdge* edge, TaskNode* (TaskEdge::*GetNode)() const,
+    const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() const) {
+  CompTaskNode* target_node = nullptr;
+  do {
+    TaskNode* tmp_node = (edge->*GetNode)();
+    edge = *((tmp_node->*GetEdges)().begin());
+    target_node = dynamic_cast<CompTaskNode*>(tmp_node);
+  } while (!target_node && edge);
+  if (target_node) { return target_node->chain_node(); }
+  return nullptr;
+}
+
+}  // namespace
+
 void CompTaskNode::ToProto(TaskProto* task_proto) {
   TaskNode::ToProto(task_proto);
   *(task_proto->mutable_parallel_ctx()) = parallel_ctx_;
 }
 
+const ChainNode* CompTaskNode::SuccChainNodeOnEdge(TaskEdge* edge) {
+  return ChainNodeOnEdge(edge, &TaskEdge::dst_node, &TaskNode::out_edges);
+}
+
+const ChainNode* CompTaskNode::PredChainNodeOnEdge(TaskEdge* edge) {
+  return ChainNodeOnEdge(edge, &TaskEdge::src_node, &TaskNode::in_edges);
+}
+
 void SortByParallelId(std::vector<CompTaskNode*>* node_vec) {
   std::sort(node_vec->begin(), node_vec->end(),
             [](const CompTaskNode* lhs, const CompTaskNode* rhs) {
......
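The ChainNodeOnEdge helper above takes pointer-to-member functions so that one traversal loop serves both directions: SuccChainNodeOnEdge passes dst_node/out_edges, PredChainNodeOnEdge passes src_node/in_edges, and the loop skips non-computation nodes (copy, boxing, and the like) until it reaches a CompTaskNode. Below is a compilable sketch of the same pattern with simplified stand-in types; Node, Edge, and is_comp are illustrative, not OneFlow's real classes.

#include <iostream>
#include <unordered_set>

struct Edge;

struct Node {
  std::unordered_set<Edge*> in_edges_;
  std::unordered_set<Edge*> out_edges_;
  bool is_comp = false;  // stands in for dynamic_cast<CompTaskNode*> succeeding
  const std::unordered_set<Edge*>& in_edges() const { return in_edges_; }
  const std::unordered_set<Edge*>& out_edges() const { return out_edges_; }
};

struct Edge {
  Node* src = nullptr;
  Node* dst = nullptr;
  Node* src_node() const { return src; }
  Node* dst_node() const { return dst; }
};

// One loop serves both directions; the member-function pointers select
// which endpoint to follow and which edge set to continue along.
const Node* FirstCompNodeOnEdge(
    Edge* edge, Node* (Edge::*GetNode)() const,
    const std::unordered_set<Edge*>& (Node::*GetEdges)() const) {
  const Node* target = nullptr;
  do {
    Node* tmp = (edge->*GetNode)();
    const auto& next = (tmp->*GetEdges)();
    edge = next.empty() ? nullptr : *next.begin();
    if (tmp->is_comp) { target = tmp; }
  } while (!target && edge);
  return target;
}

int main() {
  Node copy_node, comp_node;
  Edge e1, e2;
  e1.dst = &copy_node;                      // edge into a non-comp node
  e2.src = &copy_node; e2.dst = &comp_node; // which forwards to a comp node
  copy_node.out_edges_.insert(&e2);
  comp_node.is_comp = true;
  const Node* found =
      FirstCompNodeOnEdge(&e1, &Edge::dst_node, &Node::out_edges);
  std::cout << (found == &comp_node) << std::endl;  // 1: skipped the copy node
}

Note that the sketch guards against an empty edge set; the committed helper dereferences *begin() directly, so it relies on every such chain terminating at a comp task node before the edges run out.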
@@ -28,6 +28,9 @@ class CompTaskNode : public TaskNode {
   void set_chain_node(const ChainNode* val) { chain_node_ = val; }
 
+ protected:
+  const ChainNode* SuccChainNodeOnEdge(TaskEdge* edge);
+  const ChainNode* PredChainNodeOnEdge(TaskEdge* edge);
+
  private:
   ParallelContext parallel_ctx_;
   const ChainNode* chain_node_;
......
@@ -4,16 +4,18 @@
 namespace oneflow {
 
 void ForwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
-  std::shared_ptr<RegstDesc> activation_regst = ProduceRegst("activation");
-  std::shared_ptr<RegstDesc> data_tmp_regst = ProduceRegst("data_tmp");
   std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out");
   for (TaskEdge* edge : out_edges()) {
     TaskNode* dst_node = edge->dst_node();
-    if (IsBackwardTaskType(dst_node->GetTaskType())) {
-      edge->AddRegst("activation", activation_regst);
-      edge->AddRegst("data_tmp", data_tmp_regst);
+    if (SuccChainNodeOnEdge(edge) == chain_node()) {
+      VirtualAddRegstOnRecurrentOutEdge(edge);
+    } else {
+      edge->AddRegst("out", out_regst);
+      if (IsBackwardTaskType(dst_node->GetTaskType())) {
+        edge->AddRegst("activation", ProduceRegst("activation"));
+        edge->AddRegst("data_tmp", ProduceRegst("data_tmp"));
+      }
     }
-    edge->AddRegst("out", out_regst);
   }
 }
@@ -39,22 +41,6 @@ void ForwardCompTaskNode::BuildExecGphAndRegst() {
   });
 }
 
-void ForwardCompTaskNode::BuildOutRegst() {
-  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
-  mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
-    HashSet<std::string> found_lbns;
-    for (ExecEdge* out_edge : cur_node->out_edges()) {
-      CHECK(found_lbns.insert(out_edge->lbn()).second);
-    }
-    for (const std::string& obn : cur_node->op()->output_bns()) {
-      const std::string& lbn = cur_node->op()->Lbn4BnInOp(obn);
-      if (found_lbns.find(lbn) != found_lbns.end()) { continue; }
-      out_regst->AddLbn(lbn);
-      cur_node->BindBnInOpAndRegst(obn, out_regst);
-    }
-  });
-}
-
 void ForwardCompTaskNode::BuildActivationRegst() {
   std::shared_ptr<RegstDesc> activation_regst = GetProducedRegst("activation");
   mut_exec_gph().ForEachEdge([&](const ExecEdge* edge) {
......
@@ -17,11 +17,12 @@ class ForwardCompTaskNode : public CompTaskNode {
   void LockRegsts() override;
 
  protected:
-  virtual void VirtualConsumeInRegst(TaskEdge* edge) { UNEXPECTED_RUN(); };
-  virtual void BuildExecGphStructAndBindInRegst() { UNEXPECTED_RUN(); };
+  virtual void VirtualAddRegstOnRecurrentOutEdge(TaskEdge* edge) {}
+  virtual void VirtualConsumeInRegst(TaskEdge* edge) { UNEXPECTED_RUN(); }
+  virtual void BuildExecGphStructAndBindInRegst() { UNEXPECTED_RUN(); }
+  virtual void BuildOutRegst() { UNEXPECTED_RUN(); }
 
  private:
-  void BuildOutRegst();
   void BuildActivationRegst();
   void BuildModelAndTmpRegsts();
   void FixRegisterNumRange() override;
......
@@ -36,6 +36,22 @@ void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
   });
 }
 
+void NormalForwardCompTaskNode::BuildOutRegst() {
+  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
+  mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
+    HashSet<std::string> found_lbns;
+    for (ExecEdge* out_edge : cur_node->out_edges()) {
+      CHECK(found_lbns.insert(out_edge->lbn()).second);
+    }
+    for (const std::string& obn : cur_node->op()->output_bns()) {
+      const std::string& lbn = cur_node->op()->Lbn4BnInOp(obn);
+      if (found_lbns.find(lbn) != found_lbns.end()) { continue; }
+      out_regst->AddLbn(lbn);
+      cur_node->BindBnInOpAndRegst(obn, out_regst);
+    }
+  });
+}
+
 bool NormalForwardCompTaskNode::IsReadyForBuild() {
   return GetConsumedRegst("in")->IsLocked();
 }
......
@@ -17,6 +17,7 @@ class NormalForwardCompTaskNode final : public ForwardCompTaskNode {
  private:
   void VirtualConsumeInRegst(TaskEdge* edge) override;
   void BuildExecGphStructAndBindInRegst() override;
+  void BuildOutRegst() override;
 };
 
 }  // namespace oneflow
......
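The header changes above set up a template-method split: the ForwardCompTaskNode base owns the shared build flow and declares virtual hooks, and each subclass (NormalForwardCompTaskNode here, RecurrentForwardCompTaskNode below) overrides only the steps that differ, with un-overridden hooks failing loudly via UNEXPECTED_RUN(). A self-contained sketch of the shape of that design follows; the types are simplified and std::abort() stands in for UNEXPECTED_RUN().

#include <cstdlib>
#include <iostream>

struct TaskEdge {};

// Base class: shared build skeleton plus virtual hooks.
class ForwardCompTaskNode {
 public:
  virtual ~ForwardCompTaskNode() = default;
  void Build() {  // the shared skeleton every forward node runs
    BuildExecGphStructAndBindInRegst();
    BuildOutRegst();
  }

 protected:
  // Defaults either no-op or abort, so a subclass that forgets an
  // override cannot silently mis-build its part of the task graph.
  virtual void VirtualAddRegstOnRecurrentOutEdge(TaskEdge*) {}
  virtual void BuildExecGphStructAndBindInRegst() { std::abort(); }
  virtual void BuildOutRegst() { std::abort(); }
};

class NormalForwardCompTaskNode final : public ForwardCompTaskNode {
 private:
  void BuildExecGphStructAndBindInRegst() override {
    std::cout << "bind in regst for a normal fw node\n";
  }
  void BuildOutRegst() override {
    std::cout << "bind every unconsumed output lbn to the out regst\n";
  }
};

class RecurrentForwardCompTaskNode final : public ForwardCompTaskNode {
 private:
  void VirtualAddRegstOnRecurrentOutEdge(TaskEdge*) override {
    std::cout << "produce rec_out regst on the self-loop edge\n";
  }
  void BuildExecGphStructAndBindInRegst() override {
    std::cout << "bind in/rec_in/h0 for a recurrent fw node\n";
  }
  void BuildOutRegst() override {
    std::cout << "bind out and rec_out explicitly\n";
  }
};

int main() {
  RecurrentForwardCompTaskNode node;
  node.Build();  // recurrent overrides run; forgetting one would abort
}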
@@ -4,15 +4,32 @@
 namespace oneflow {
 
+void RecurrentForwardCompTaskNode::VirtualAddRegstOnRecurrentOutEdge(
+    TaskEdge* edge) {
+  int32_t max_regst_num = -1;
+  if (parallel_ctx()->policy() == kDataParallel) {
+    max_regst_num = 1;
+  } else if (parallel_ctx()->policy() == kModelParallel) {
+    max_regst_num = kMaxRegisterNum;
+  } else {
+    UNEXPECTED_RUN();
+  }
+  edge->AddRegst("rec_out", ProduceRegst("rec_out", 1, max_regst_num));
+}
+
 void RecurrentForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) {
   std::shared_ptr<const Operator> op = chain_node()->SoleOp();
   std::shared_ptr<RegstDesc> regst = edge->GetSoleRegst();
-  if (regst->GetBlobDesc(op->Lbn4BnInOp("in"))) {
+  const HashSet<std::string>& lbns =
+      PredChainNodeOnEdge(edge)->data_output_lbns();
+  if (lbns.find(op->Lbn4BnInOp("in")) != lbns.end()) {
     ConsumeRegst("in", regst);
-  } else if (regst->GetBlobDesc(op->Lbn4BnInOp("h0"))) {
+  } else if (lbns.find(op->Lbn4BnInOp("h0")) != lbns.end()) {
     ConsumeRegst("h0", regst);
-  } else if (regst->GetBlobDesc(op->Lbn4BnInOp("ht_1"))) {
-    ConsumeRegst("ht_1", regst);
+  } else if (lbns.find(op->Lbn4BnInOp("rec_in")) != lbns.end()) {
+    if (parallel_ctx()->policy() == kModelParallel) {
+      ConsumeRegst("rec_in", regst);
+    }
   } else {
     UNEXPECTED_RUN();
   }
@@ -24,11 +41,30 @@ void RecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
   ExecNode* exec_node = mut_exec_gph().NewNode();
   exec_node->mut_op() = op;
   exec_node->BindBnInOpAndRegst("in", GetConsumedRegst("in"));
-  exec_node->BindBnInOpAndRegst("ht_1", GetConsumedRegst("ht_1"));
+  if (parallel_ctx()->policy() == kModelParallel) {
+    exec_node->BindBnInOpAndRegst("rec_in", GetConsumedRegst("rec_in"));
+  } else if (parallel_ctx()->policy() == kDataParallel) {
+    exec_node->BindBnInOpAndRegst("rec_in", GetProducedRegst("rec_out"));
+  } else {
+    UNEXPECTED_RUN();
+  }
   std::shared_ptr<RegstDesc> h0_regst = GetConsumedRegst("h0");
   if (h0_regst) { exec_node->BindBnInOpAndRegst("h0", h0_regst); }
 }
 
+void RecurrentForwardCompTaskNode::BuildOutRegst() {
+  std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
+  std::shared_ptr<RegstDesc> rec_out_regst = GetProducedRegst("rec_out");
+  CHECK(out_regst && rec_out_regst);
+  ExecNode* exec_node = mut_exec_gph().SoleNode();
+  const std::string& out_lbn = exec_node->op()->Lbn4BnInOp("out");
+  const std::string& rec_out_lbn = exec_node->op()->Lbn4BnInOp("rec_out");
+  out_regst->AddLbn(out_lbn);
+  rec_out_regst->AddLbn(rec_out_lbn);
+  exec_node->BindBnInOpAndRegst("out", out_regst);
+  exec_node->BindBnInOpAndRegst("rec_out", rec_out_regst);
+}
+
 bool RecurrentForwardCompTaskNode::IsReadyForBuild() {
   std::shared_ptr<RegstDesc> regst = GetConsumedRegst("h0");
   if (GetConsumedRegst("in")->IsLocked() && (!regst || regst->IsLocked())) {
......
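The parallel-policy split above appears to work as follows: under kDataParallel the recurrent loop stays on one node, so rec_in is bound to the node's own produced "rec_out" regst and the register count is pinned to exactly one (min 1, max 1), making step t read precisely what step t-1 wrote; under kModelParallel the loop crosses the recurrent edge between nodes, rec_in is consumed from that edge, and up to kMaxRegisterNum registers permit pipelining. A minimal sketch of that policy, with a hypothetical helper name and a placeholder constant value (the real kMaxRegisterNum is defined elsewhere in OneFlow):

#include <cstdint>
#include <iostream>
#include <stdexcept>

enum ParallelPolicy { kDataParallel, kModelParallel };

// Illustrative stand-in; not OneFlow's actual value.
constexpr int32_t kMaxRegisterNum = 128;

int32_t MaxRegstNumForRecurrentOut(ParallelPolicy policy) {
  switch (policy) {
    case kDataParallel:
      // Self-loop on one node: a single buffer serializes the timesteps.
      return 1;
    case kModelParallel:
      // Loop crosses model slices: allow multiple in-flight registers.
      return kMaxRegisterNum;
  }
  throw std::logic_error("unexpected parallel policy");
}

int main() {
  std::cout << MaxRegstNumForRecurrentOut(kDataParallel) << " "
            << MaxRegstNumForRecurrentOut(kModelParallel) << std::endl;  // 1 128
}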
@@ -15,8 +15,10 @@ class RecurrentForwardCompTaskNode final : public ForwardCompTaskNode {
   bool IsReadyForBuild() override;
 
  private:
+  void VirtualAddRegstOnRecurrentOutEdge(TaskEdge* edge) override;
   void VirtualConsumeInRegst(TaskEdge* edge) override;
   void BuildExecGphStructAndBindInRegst() override;
+  void BuildOutRegst() override;
 };
 
 }  // namespace oneflow
......
@@ -34,15 +34,15 @@ void RecurrentOp::InitFromOpConf() {
   CHECK(op_conf().has_recurrent_conf());
   const RecurrentOpConf& conf = op_conf().recurrent_conf();
   EnrollInputBn("in");
-  EnrollInputBn("ht_1");
+  EnrollInputBn("rec_in");
   if (!conf.init_hidden().empty()) {
     CHECK(!conf.has_init_hidden_initializer());
     EnrollInputBn("h0");
   } else {
     EnrollModelBn("h0");
   }
-  EnrollOutputBn("ht");
-  EnrollOutputBn("rec_ht");
+  EnrollOutputBn("out");
+  EnrollOutputBn("rec_out");
   if (conf.rnn_type_case() == RecurrentOpConf::kBasicRnnCell) {
     EnrollDataTmpBn("in_ip_op_out");
@@ -90,12 +90,12 @@ void RecurrentOp::InferBlobDescs(
     BalancedSplitter splitter(hidden_size, parallel_ctx->parallel_num());
     hidden_size = splitter.At(parallel_ctx->parallel_id()).size();
   }
-  // ht
-  BlobDesc ht_blob_desc = *in_blob_desc;
-  ht_blob_desc.mut_shape() = Shape({data_num, hidden_size});
-  *GetBlobDesc4BnInOp("ht") = ht_blob_desc;
-  // recurrent_ht
-  *GetBlobDesc4BnInOp("rec_ht") = ht_blob_desc;
+  // out
+  BlobDesc out_blob_desc = *in_blob_desc;
+  out_blob_desc.mut_shape() = Shape({data_num, hidden_size});
+  *GetBlobDesc4BnInOp("out") = out_blob_desc;
+  // recurrent_out
+  *GetBlobDesc4BnInOp("rec_out") = out_blob_desc;
   if (op_conf().recurrent_conf().rnn_type_case()
       == RecurrentOpConf::kBasicRnnCell) {
@@ -109,8 +109,8 @@ void RecurrentOp::InferBlobDescs(
 }
 
 std::string RecurrentOp::ibn2lbn(const std::string& input_bn) const {
-  if (input_bn == "ht_1") {
-    return obn2lbn("ht");
+  if (input_bn == "rec_in") {
+    return obn2lbn("rec_out");
   } else if (input_bn == "h0") {
     return op_conf().recurrent_conf().init_hidden();
   } else if (input_bn == "in") {
......
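The rename from ht/rec_ht/ht_1 to out/rec_out/rec_in finishes in ibn2lbn above, where the input bn "rec_in" resolves to the lbn of the op's own output bn "rec_out"; this is what closes the recurrence at the operator level before the task graph materializes it as the self-loop edge handled earlier. A toy sketch of that resolution, in which the lbn scheme and helper signatures are simplified assumptions rather than RecurrentOp's real ones:

#include <iostream>
#include <string>

// Simplified lbn scheme: "op_name/blob_name".
std::string obn2lbn(const std::string& op_name, const std::string& obn) {
  return op_name + "/" + obn;
}

std::string ibn2lbn(const std::string& op_name, const std::string& ibn) {
  if (ibn == "rec_in") {
    // The recurrent input is this op's own recurrent output, one step back.
    return obn2lbn(op_name, "rec_out");
  }
  // Placeholder for the normal producer lookup of other inputs.
  return op_name + "_producer/" + ibn;
}

int main() {
  std::cout << ibn2lbn("rnn0", "rec_in") << std::endl;  // prints "rnn0/rec_out"
}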