提交 3589c342 编写于 作者: W willzhang4a58

refine task_node::build_exec

上级 b55817ac
...@@ -8,28 +8,28 @@ namespace oneflow { ...@@ -8,28 +8,28 @@ namespace oneflow {
namespace { namespace {
void FwCompleteBoxOpConfDataData(BoxingOpConf* conf) { void FwCompleteBoxOpConfDataData(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(0); conf->mutable_concat_box()->set_type(BoxConcatConf::kData);
conf->mutable_split_box()->set_axis(0); conf->mutable_data_split_box();
} }
void FwCompleteBoxOpConfDataModel(BoxingOpConf* conf) { void FwCompleteBoxOpConfDataModel(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(0); conf->mutable_concat_box()->set_type(BoxConcatConf::kData);
conf->mutable_clone_box(); conf->mutable_clone_box();
} }
void FwCompleteBoxOpConfModelData(BoxingOpConf* conf) { void FwCompleteBoxOpConfModelData(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(1); conf->mutable_concat_box()->set_type(BoxConcatConf::kModel);
conf->mutable_split_box()->set_axis(0); conf->mutable_data_split_box();
} }
void FwCompleteBoxOpConfModelModel(BoxingOpConf* conf) { void FwCompleteBoxOpConfModelModel(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(1); conf->mutable_concat_box()->set_type(BoxConcatConf::kModel);
conf->mutable_clone_box(); conf->mutable_clone_box();
} }
} // namespace } // namespace
void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) { void BoxingTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
EnrollAllRegstAndBindRelatedEdge(); EnrollAllRegstAndBindRelatedEdge();
FwVirtualBuild(); FwVirtualBuild();
} }
...@@ -37,11 +37,11 @@ void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) { ...@@ -37,11 +37,11 @@ void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
void BoxingTaskNode::EnrollAllRegstAndBindRelatedEdge() { void BoxingTaskNode::EnrollAllRegstAndBindRelatedEdge() {
for (TaskEdge* edge : out_edges()) { for (TaskEdge* edge : out_edges()) {
std::string name = "boxing_out_" + edge->edge_id_str(); std::string name = "boxing_out_" + edge->edge_id_str();
auto regst_desc = of_make_unique<DisContigRegstDesc> (); auto regst_desc = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(regst_desc.get(), edge); BindProducedRegstAndOutEdge(regst_desc.get(), edge);
EnrollProducedRegstDesc(name, std::move(regst_desc)); EnrollProducedRegstDesc(name, std::move(regst_desc));
} }
auto regst_desc = of_make_unique<DisContigRegstDesc> (); auto regst_desc = RegstDescMgr::Singleton().CreateRegisterDesc();
EnrollProducedRegstDesc("middle", std::move(regst_desc)); EnrollProducedRegstDesc("middle", std::move(regst_desc));
} }
...@@ -133,7 +133,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair( ...@@ -133,7 +133,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const std::string& ibn = node->op()->input_bns().at(i); const std::string& ibn = node->op()->input_bns().at(i);
std::string lbn = node->op()->ibn2lbn(ibn); std::string lbn = node->op()->ibn2lbn(ibn);
Shape* ptr = in_regst->GetMutShapePtr(lbn); Shape* ptr = in_regst->GetMutShapePtr(lbn);
node->op()->SetShapePtr(ibn, ptr); node->BindBnInOpAndShapePtr(ibn, ptr);
node->BindBnInOpAndRegst(ibn, in_regst); node->BindBnInOpAndRegst(ibn, in_regst);
} }
// obn // obn
...@@ -142,16 +142,22 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair( ...@@ -142,16 +142,22 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const std::string& obn = node->op()->output_bns().at(i); const std::string& obn = node->op()->output_bns().at(i);
std::string lbn = node->op()->obn2lbn(obn); std::string lbn = node->op()->obn2lbn(obn);
Shape* ptr = out_regst->EnrollLbn(lbn); Shape* ptr = out_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(obn, ptr); node->BindBnInOpAndShapePtr(obn, ptr);
node->BindBnInOpAndRegst(obn, out_regst); node->BindBnInOpAndRegst(obn, out_regst);
} }
// dtbn // dtbn
for (const std::string& dtbn : node->op()->data_tmp_bns()) { for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn); std::string lbn = node->op()->dtbn2lbn(dtbn);
Shape* ptr = middle_regst->EnrollLbn(lbn); Shape* ptr = middle_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(dtbn, ptr); node->BindBnInOpAndShapePtr(dtbn, ptr);
node->BindBnInOpAndRegst(dtbn, middle_regst);
} }
node->op()->InferShape4ObAndDtbFromIb(); }
}
void BoxingTaskNode::FwInferShape4LbnInProducedRegsts(TaskGraph*) {
for (const auto& exec_node : exec_gph().nodes()) {
exec_node->op()->InferShape4ObAndDtbFromIb(exec_node->BnInOp2ShapePtr());
} }
} }
...@@ -165,10 +171,9 @@ inline RegstDesc* GetBpRegstFromFwRegst(RegstDesc* fw_regst) { ...@@ -165,10 +171,9 @@ inline RegstDesc* GetBpRegstFromFwRegst(RegstDesc* fw_regst) {
} }
void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) { void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
EnrollAllRegstAndBindRelatedEdge(); EnrollAllRegstAndBindRelatedEdge();
const ExecGraph& fw_exec_gph = GetFwNode()->exec_gph(); const ExecGraph& fw_exec_gph = GetFwNode()->exec_gph();
HashMap<const ExecNode*, ExecNode*> fw_node2bp_node;
for (const std::unique_ptr<ExecNode>& fw_node: fw_exec_gph.nodes()) { for (const std::unique_ptr<ExecNode>& fw_node: fw_exec_gph.nodes()) {
ExecNode* bp_node = mut_exec_gph().NewNode(); ExecNode* bp_node = mut_exec_gph().NewNode();
bp_node->mut_op() = fw_node->op(); bp_node->mut_op() = fw_node->op();
...@@ -178,8 +183,7 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) { ...@@ -178,8 +183,7 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
std::string lbn = fw_node->op()->ibn2lbn(ibn); std::string lbn = fw_node->op()->ibn2lbn(ibn);
RegstDesc* in_regst = fw_node->GetRegstFromBnInOp(ibn); RegstDesc* in_regst = fw_node->GetRegstFromBnInOp(ibn);
RegstDesc* in_diff_regst = GetBpRegstFromFwRegst(in_regst); RegstDesc* in_diff_regst = GetBpRegstFromFwRegst(in_regst);
Shape* in_diff_shape_ptr = in_diff_regst->EnrollLbn(lbn); in_diff_regst->EnrollLbn(lbn);
*in_diff_shape_ptr = in_regst->GetShape(lbn);
bp_node->BindBnInOpAndRegst(idbn, in_diff_regst); bp_node->BindBnInOpAndRegst(idbn, in_diff_regst);
} }
// out_diff // out_diff
...@@ -195,12 +199,22 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) { ...@@ -195,12 +199,22 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
std::string lbn = fw_node->op()->dtbn2lbn(dtbn); std::string lbn = fw_node->op()->dtbn2lbn(dtbn);
RegstDesc* fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle"); RegstDesc* fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
RegstDesc* bp_middle_regst = GetProducedRegstDesc("middle"); RegstDesc* bp_middle_regst = GetProducedRegstDesc("middle");
Shape* ptr = bp_middle_regst->EnrollLbn(lbn); bp_middle_regst->EnrollLbn(lbn);
*ptr = fw_middle_regst->GetShape(lbn);
bp_node->BindBnInOpAndRegst(dtbn, bp_middle_regst); bp_node->BindBnInOpAndRegst(dtbn, bp_middle_regst);
} }
} }
mut_exec_gph().UpdateSourceAndSink(); mut_exec_gph().UpdateSourceAndSink();
} }
void BoxingTaskNode::BpInferShape4LbnInProducedRegsts(TaskGraph*) {
for (TaskEdge* fw_in_edge : GetFwNode()->in_edges()) {
RegstDesc* in_regst = GetRelatedRegst(fw_in_edge);
RegstDesc* in_diff_regst = GetBpRegstFromFwRegst(in_regst);
in_diff_regst->CopyShapeFrom(in_regst);
}
RegstDesc* fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
RegstDesc* bp_middle_regst = GetProducedRegstDesc("middle");
bp_middle_regst->CopyShapeFrom(fw_middle_regst);
}
} // namespace oneflow } // namespace oneflow
...@@ -12,7 +12,7 @@ class BoxingTaskNode : public TaskNode { ...@@ -12,7 +12,7 @@ class BoxingTaskNode : public TaskNode {
virtual ~BoxingTaskNode() = default; virtual ~BoxingTaskNode() = default;
std::string VisualStr() const override { std::string VisualStr() const override {
return TaskNode::VisualStr() + "Boxing_" + node_id_str(); return TaskNode::VisualStr() + "Boxing";
} }
protected: protected:
...@@ -39,8 +39,10 @@ class BoxingTaskNode : public TaskNode { ...@@ -39,8 +39,10 @@ class BoxingTaskNode : public TaskNode {
virtual void FwVirtualBuild() = 0; virtual void FwVirtualBuild() = 0;
private: private:
void FwBuildExecAndProducedRegsts(TaskGraph*) override; void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpBuildExecAndProducedRegsts(TaskGraph*) override; void FwInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
void EnrollAllRegstAndBindRelatedEdge(); void EnrollAllRegstAndBindRelatedEdge();
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
namespace oneflow { namespace oneflow {
void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){ void CommNetTaskNode::CommNetBuildExecAndEnrollLbn2Regsts() {
auto out_regst = of_make_unique<DisContigRegstDesc> (); auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge()); BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge()); RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyLbn2ShapeMap(in_regst); out_regst->CopyLbnFrom(in_regst);
OperatorConf op_conf; OperatorConf op_conf;
op_conf.set_name("comm_net_" + NewUniqueId()); op_conf.set_name("comm_net_" + NewUniqueId());
...@@ -22,15 +22,29 @@ void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){ ...@@ -22,15 +22,29 @@ void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){
node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst.get()); node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst.get());
mut_exec_gph().UpdateSourceAndSink(); mut_exec_gph().UpdateSourceAndSink();
EnrollProducedRegstDesc("comm_net", std::move(out_regst)); EnrollProducedRegstDesc("out", std::move(out_regst));
} }
void CommNetTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) { void CommNetTaskNode::CommNetInferShape4LbnInProducedRegsts() {
BuildExecAndProducedRegstsForNetCopy(gph); RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
out_regst->CopyShapeFrom(in_regst);
}
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override {
return CommNetBuildExecAndEnrollLbn2Regsts();
}
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override {
return CommNetInferShape4LbnInProducedRegsts();
}
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override {
return CommNetBuildExecAndEnrollLbn2Regsts();
} }
void CommNetTaskNode::BpBuildExecAndProducedRegsts(TaskGraph* gph) { void BpInferShape4LbnInProducedRegsts(TaskGraph*) override {
BuildExecAndProducedRegstsForNetCopy(gph); return CommNetInferShape4LbnInProducedRegsts();
} }
} // namespace oneflow } // namespace oneflow
...@@ -29,13 +29,17 @@ class CommNetTaskNode final : public TaskNode { ...@@ -29,13 +29,17 @@ class CommNetTaskNode final : public TaskNode {
} }
std::string VisualStr() const override { std::string VisualStr() const override {
return TaskNode::VisualStr() + "CommNet_" + node_id_str(); return TaskNode::VisualStr() + "CommNet";
} }
private: private:
void BuildExecAndProducedRegstsForNetCopy(TaskGraph*); void CommNetBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void FwBuildExecAndProducedRegsts(TaskGraph*) override; void CommNetInferShape4LbnInProducedRegsts();
void BpBuildExecAndProducedRegsts(TaskGraph*) override;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
std::unique_ptr<TaskNode> CreateSameTypeNode() const override { std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CommNetTaskNode> (); return of_make_unique<CommNetTaskNode> ();
......
...@@ -5,27 +5,52 @@ ...@@ -5,27 +5,52 @@
namespace oneflow { namespace oneflow {
void CompTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) { std::string CompTaskNode::VisualStr() const override {
(this->*(gph->Func4FwBuildExecAndProducedRegsts()))(gph); std::stringstream ss;
ss << TaskNode::VisualStr()
<< "Compute" << ":"
<< stage_node()->machine_id_str() << ":"
<< thrd_loc_id_str() << "\\n"
<< chain_node()->VisualStr();
return ss.str();
} }
void CompTaskNode::DataFwBuildExecAndProducedRegsts(TaskGraph*) { void CompTaskNode::DataFwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
Lbn2NodeBnMap lbn2producer; Lbn2NodeBnMap lbn2producer;
Lbn2NodeBnMap extern_in_lbn2consumer; Lbn2NodeBnMap extern_in_lbn2consumer;
FwBuildFromUserOps(&lbn2producer, &extern_in_lbn2consumer); FwBuildFromUserOps(&lbn2producer, &extern_in_lbn2consumer);
mut_exec_gph().UpdateSourceAndSink(); mut_exec_gph().UpdateSourceAndSink();
// data regst // produced regsts
auto data_regst = of_make_unique<DisContigRegstDesc> (); auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(data_regst.get(), SoleOutEdge()); auto activation_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
EnrollProducedRegstDesc("data", std::move(data_regst)); auto data_tmp_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
FwSetDataRegstDesc(lbn2producer, extern_in_lbn2consumer); auto model_tmp_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
// model_tmp regst // Bind Out Edge
auto model_tmp_regst = of_make_unique<DisContigRegstDesc> (); BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
// EnrollProducedRegstDesc
EnrollProducedRegstDesc("out", std::move(out_regst));
EnrollProducedRegstDesc("activation", std::move(activation_regst));
EnrollProducedRegstDesc("data_tmp", std::move(data_tmp_regst));
EnrollProducedRegstDesc("model_tmp", std::move(model_tmp_regst)); EnrollProducedRegstDesc("model_tmp", std::move(model_tmp_regst));
FwSetModelTmpRegstDesc(); // Enroll Lbn
FwSetExecNodeFromInRegst(extern_in_lbn2consumer);
FwEnrollLbn2OutRegst(lbn2producer);
FwEnrollLbn2ActivationRegst();
FwEnrollLbn2TmpRegsts();
} }
void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) { void CompTaskNode::DataFwInferShape4LbnInProducedRegsts() {
for (const ExecNode& node : exec_gph()) {
node.op()->InferShape4ObAndDtbFromIb(node.BnInOp2ShapePtr());
node.op()->InferShape4ModelTmpBlob(node.BnInOp2ShapePtr(),
chain_node()->parallel_desc()->policy(),
parallel_id());
}
}
void CompTaskNode::MdUpdtFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
TODO();
/*
if (IsFaker()) { if (IsFaker()) {
CompTaskNode* mccoy = gph->faker2mccoy().at(this); CompTaskNode* mccoy = gph->faker2mccoy().at(this);
RegstDesc* regst = mccoy->GetProducedRegstDesc("model_diff"); RegstDesc* regst = mccoy->GetProducedRegstDesc("model_diff");
...@@ -39,9 +64,16 @@ void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) { ...@@ -39,9 +64,16 @@ void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink(); mut_exec_gph().UpdateSourceAndSink();
// PostProcessing in ModelUpdateTaskGraph will complete the work // PostProcessing in ModelUpdateTaskGraph will complete the work
// which should be implemented in this function // which should be implemented in this function
*/
}
void CompTaskNode::MdUpdtFwInferShape4LbnInProducedRegsts(TaskGraph* gph) {
TODO();
} }
void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) { void CompTaskNode::MdLoadFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
TODO();
/*
if (IsFaker()) { if (IsFaker()) {
CompTaskNode* update_task = gph->faker2mccoy().at(this); CompTaskNode* update_task = gph->faker2mccoy().at(this);
ExecNode* exec_node = update_task->exec_gph().SoleNode(); ExecNode* exec_node = update_task->exec_gph().SoleNode();
...@@ -58,9 +90,16 @@ void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) { ...@@ -58,9 +90,16 @@ void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) {
exec_node->op()->SetShapePtr(exec_node->op()->SoleObn(), shape_ptr); exec_node->op()->SetShapePtr(exec_node->op()->SoleObn(), shape_ptr);
exec_node->op()->InferShape4ObAndDtbFromIb(); exec_node->op()->InferShape4ObAndDtbFromIb();
EnrollProducedRegstDesc("model_regst", std::move(model_regst)); EnrollProducedRegstDesc("model_regst", std::move(model_regst));
*/
} }
void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) { void CompTaskNode::MdLoadFwInferShape4LbnInProducedRegsts(TaskGraph* gph) {
TODO();
}
void CompTaskNode::MdSaveFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
TODO();
/*
if (IsFaker()) { if (IsFaker()) {
CompTaskNode* update_task = gph->faker2mccoy().at(this); CompTaskNode* update_task = gph->faker2mccoy().at(this);
RegstDesc* model_regst = update_task->GetProducedRegstDesc("model"); RegstDesc* model_regst = update_task->GetProducedRegstDesc("model");
...@@ -72,6 +111,11 @@ void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) { ...@@ -72,6 +111,11 @@ void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink(); mut_exec_gph().UpdateSourceAndSink();
const std::string& ibn = exec_node->op()->SoleIbn(); const std::string& ibn = exec_node->op()->SoleIbn();
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge())); exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
*/
}
void CompTaskNode::MdSaveFwInferShape4LbnInProducedRegsts(TaskGraph* gph) {
TODO();
} }
void CompTaskNode::FwBuildFromUserOps( void CompTaskNode::FwBuildFromUserOps(
...@@ -103,80 +147,97 @@ void CompTaskNode::FwBuildFromUserOps( ...@@ -103,80 +147,97 @@ void CompTaskNode::FwBuildFromUserOps(
} }
} }
void CompTaskNode::FwSetDataRegstDesc( void CompTaskNode::FwSetExecNodeFromInRegst(
const Lbn2NodeBnMap& lbn2producer,
const Lbn2NodeBnMap& extern_in_lbn2consumer) { const Lbn2NodeBnMap& extern_in_lbn2consumer) {
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge()); RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
// blob on exec_edge
for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) {
Shape* ptr = out_regst->EnrollLbn(edge->lbn());
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), out_regst);
edge->src_node()->op()->SetShapePtr(edge->src_bn(), ptr);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), out_regst);
edge->dst_node()->op()->SetShapePtr(edge->dst_bn(), ptr);
}
// extern in blobs
for (const auto& pair : extern_in_lbn2consumer) { for (const auto& pair : extern_in_lbn2consumer) {
const std::string& lbn = pair.first; const std::string& lbn = pair.first;
Shape* ptr = in_regst->GetMutShapePtr(lbn); Shape* ptr = in_regst->GetMutShapePtr(lbn);
ExecNode* node = pair.second.first; ExecNode* node = pair.second.first;
const std::string& ibn = pair.second.second; const std::string& ibn = pair.second.second;
node->op()->SetShapePtr(ibn, ptr); node->BindBnInOpAndShapePtr(ibn, ptr);
node->BindBnInOpAndRegst(ibn, in_regst); node->BindBnInOpAndRegst(ibn, in_regst);
} }
// extern out blobs }
void CompTaskNode::FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer) {
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
for (const std::string& lbn : chain_node()->output_lbns()) { for (const std::string& lbn : chain_node()->output_lbns()) {
const std::pair<ExecNode*, std::string>& producer = lbn2producer.at(lbn); const std::pair<ExecNode*, std::string>& producer = lbn2producer.at(lbn);
ExecNode* node = producer.first; ExecNode* node = producer.first;
const std::string& obn = producer.second; const std::string& obn = producer.second;
Shape* ptr = out_regst->EnrollLbn(lbn); Shape* ptr = out_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(obn, ptr); node->BindBnInOpAndShapePtr(obn, ptr);
node->BindBnInOpAndRegst(obn, out_regst); node->BindBnInOpAndRegst(obn, out_regst);
} }
// data tmp blobs }
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) { void CompTaskNode::FwEnrollLbn2ActivationRegst() {
std::string lbn = node->op()->dtbn2lbn(dtbn); RegstDesc* activation_regst = GetProducedRegstDesc("activation");
Shape* ptr = out_regst->EnrollLbn(lbn); for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) {
node->op()->SetShapePtr(dtbn, ptr); Shape* ptr = activation_regst->EnrollLbn(edge->lbn());
node->BindBnInOpAndRegst(dtbn, out_regst); edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_regst);
} edge->src_node()->BindBnInOpAndShapePtr(edge->src_bn(), ptr);
} edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_regst);
// Inference Shape edge->dst_node()->BindBnInOpAndShapePtr(edge->dst_bn(), ptr);
for (const ExecNode& node : exec_gph()) {
node.op()->InferShape4ObAndDtbFromIb();
} }
} }
void CompTaskNode::FwSetModelTmpRegstDesc() { void CompTaskNode::FwEnrollLbn2TmpRegsts() {
RegstDesc* data_tmp_regst = GetProducedRegstDesc("data_tmp");
RegstDesc* model_tmp_regst = GetProducedRegstDesc("model_tmp"); RegstDesc* model_tmp_regst = GetProducedRegstDesc("model_tmp");
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) { for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn);
Shape* ptr = data_tmp_regst->EnrollLbn(lbn);
node->BindBnInOpAndShapePtr(dtbn, ptr);
node->BindBnInOpAndRegst(dtbn, out_regst);
}
for (const std::string& mtbn : node->op()->model_tmp_bns()) { for (const std::string& mtbn : node->op()->model_tmp_bns()) {
std::string lbn = node->op()->mtbn2lbn(mtbn); std::string lbn = node->op()->mtbn2lbn(mtbn);
Shape* ptr = model_tmp_regst->EnrollLbn(lbn); Shape* ptr = model_tmp_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(mtbn, ptr); node->BindBnInOpAndShapePtr(mtbn, ptr);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst); node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
} }
node->op()->InferShape4ModelTmpBlob(chain_node()->parallel_desc()->policy(),
parallel_id());
} }
} }
void CompTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) { void CompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
const ExecGraph& fw_gph = GetFwNode()->exec_gph(); const ExecGraph& fw_gph = GetFwNode()->exec_gph();
HashMap<const ExecNode*, ExecNode*> fw_node2bp_node; HashMap<const ExecNode*, ExecNode*> fw_node2bp_node;
HashMap<ExecEdge*, const ExecEdge*> bp_edge2fw_edge; HashMap<ExecEdge*, const ExecEdge*> bp_edge2fw_edge;
BpBuildExecGraph(fw_gph, &fw_node2bp_node, &bp_edge2fw_edge); BpBuildExecGraph(fw_gph, &fw_node2bp_node, &bp_edge2fw_edge);
// // Produced registers
auto data_diff_regst = of_make_unique<DisContigRegstDesc> (); auto in_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(data_diff_regst.get(), SoleOutEdge()); auto model_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
EnrollProducedRegstDesc("data_diff", std::move(data_diff_regst)); auto activation_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BpSetDataDiffRegst(fw_node2bp_node, bp_edge2fw_edge); // Bind out edge
// BindProducedRegstAndOutEdge(in_diff_regst.get(), SoleOutEdge());
auto model_diff_regst = of_make_unique<ContigRegstDesc> (); // Enroll registers
EnrollProducedRegstDesc("in_diff", std::move(in_diff_regst));
EnrollProducedRegstDesc("model_diff", std::move(model_diff_regst)); EnrollProducedRegstDesc("model_diff", std::move(model_diff_regst));
BpSetModelDiffRegst(); EnrollProducedRegstDesc("activation_diff", std::move(activation_diff_regst));
// Enroll Lbn
BpEnrollLbn2ProducedRegst(fw_node2bp_node, bp_edge2fw_edge);
}
void CompTaskNode::BpInferShape4LbnInProducedRegsts(TaskGraph*) {
// in_diff_regst
RegstDesc* in_diff_regst = GetRelatedRegst(SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
in_diff_regst->CopyShapeFrom(in_regst);
// model_diff_regst
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
for (const std::unique_ptr<ExecNode>& cur_node : exec_gph().nodes()) {
cur_node->op()->InferShape4ModelDiffBlob(
cur_node->BnInOp2ShapePtr(),
chain_node()->parallel_desc()->policy(),
parallel_id());
}
// activation_diff_regst
RegstDesc* activation_diff_regst = GetProducedRegstDesc("activation_diff");
RegstDesc* activation_regst = GetFwNode()->GetProducedRegstDesc("activation");
activation_diff_regst->CopyShapeFrom(activation_regst);
} }
void CompTaskNode::BpBuildExecGraph( void CompTaskNode::BpBuildExecGraph(
...@@ -200,21 +261,23 @@ void CompTaskNode::BpBuildExecGraph( ...@@ -200,21 +261,23 @@ void CompTaskNode::BpBuildExecGraph(
} }
} }
void CompTaskNode::BpSetDataDiffRegst( void CompTaskNode::BpEnrollLbn2ProducedRegst(
const HashMap<const ExecNode*, ExecNode*>& fw_node2bp_node, const HashMap<const ExecNode*, ExecNode*>& fw_node2bp_node,
const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge) { const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge) {
// Regsts // Regsts
RegstDesc* in_diff_regst = GetRelatedRegst(SoleOutEdge()); RegstDesc* in_diff_regst = GetRelatedRegst(SoleOutEdge());
RegstDesc* out_diff_regst = GetRelatedRegst(SoleInEdge()); RegstDesc* out_diff_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge()); RegstDesc* in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(GetFwNode()->SoleOutEdge()); RegstDesc* activation_regst = GetFwNode()->GetProducedRegstDesc("activation");
RegstDesc* data_tmp_regst = GetFwNode()->GetProducedRegstDesc("data_tmp");
RegstDesc* model_tmp_regst = GetFwNode()->GetProducedRegstDesc("model_tmp"); RegstDesc* model_tmp_regst = GetFwNode()->GetProducedRegstDesc("model_tmp");
RegstDesc* activation_diff_regst = GetProducedRegstDesc("activation_diff");
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
// blobs on edge // blobs on edge
for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) { for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) {
Shape* ptr = in_diff_regst->EnrollLbn(edge->lbn()); activation_diff_regst->EnrollLbn(edge->lbn());
*ptr = out_regst->GetShape(edge->lbn()); edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_diff_regst);
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), in_diff_regst); edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_diff_regst);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), in_diff_regst);
} }
// extern out_diff blobs // extern out_diff blobs
for (const std::unique_ptr<ExecNode>& bp_node : exec_gph().nodes()) { for (const std::unique_ptr<ExecNode>& bp_node : exec_gph().nodes()) {
...@@ -237,35 +300,27 @@ void CompTaskNode::BpSetDataDiffRegst( ...@@ -237,35 +300,27 @@ void CompTaskNode::BpSetDataDiffRegst(
for (const std::string& idbn : bp_node->op()->input_diff_bns()) { for (const std::string& idbn : bp_node->op()->input_diff_bns()) {
if (found_bns.find(idbn) != found_bns.end()) { continue; } if (found_bns.find(idbn) != found_bns.end()) { continue; }
std::string lbn = bp_node->op()->idbn2lbn(idbn); std::string lbn = bp_node->op()->idbn2lbn(idbn);
Shape* ptr = in_diff_regst->EnrollLbn(lbn); in_diff_regst->EnrollLbn(lbn);
*ptr = in_regst->GetShape(lbn);
bp_node->BindBnInOpAndRegst(idbn, in_diff_regst); bp_node->BindBnInOpAndRegst(idbn, in_diff_regst);
bp_node->BindBnInOpAndRegst(GenUnDiffBn(idbn), in_regst);
} }
} }
// tmp blobs // tmp blobs and model_diff blobs
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) { for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) { for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn); std::string lbn = node->op()->dtbn2lbn(dtbn);
node->BindBnInOpAndRegst(dtbn, out_regst); node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
} }
for (const std::string& mtbn : node->op()->model_tmp_bns()) { for (const std::string& mtbn : node->op()->model_tmp_bns()) {
std::string lbn = node->op()->mtbn2lbn(mtbn); std::string lbn = node->op()->mtbn2lbn(mtbn);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst); node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
} }
}
}
void CompTaskNode::BpSetModelDiffRegst() {
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
for (const std::unique_ptr<ExecNode>& cur_node : exec_gph().nodes()) {
for (const std::string& mdbn : cur_node->op()->model_diff_bns()) { for (const std::string& mdbn : cur_node->op()->model_diff_bns()) {
std::string lbn = cur_node->op()->mdbn2lbn(mdbn); std::string lbn = cur_node->op()->mdbn2lbn(mdbn);
Shape* ptr = model_diff_regst->EnrollLbn(lbn); Shape* ptr = model_diff_regst->EnrollLbn(lbn);
cur_node->op()->SetShapePtr(mdbn, ptr); cur_node->BindBnInOpAndShapePtr(mdbn, ptr);
cur_node->BindBnInOpAndRegst(mdbn, model_diff_regst); cur_node->BindBnInOpAndRegst(mdbn, model_diff_regst);
} }
cur_node->op()->InferShape4ModelDiffBlob(chain_node()->parallel_desc()->policy(),
parallel_id());
} }
} }
......
...@@ -12,55 +12,60 @@ class CompTaskNode : public TaskNode { ...@@ -12,55 +12,60 @@ class CompTaskNode : public TaskNode {
CompTaskNode() = default; CompTaskNode() = default;
virtual ~CompTaskNode() = default; virtual ~CompTaskNode() = default;
// Getters and Setters
uint64_t parallel_id() const { return parallel_id_; } uint64_t parallel_id() const { return parallel_id_; }
void set_parallel_id(uint64_t parallel_id) { parallel_id_ = parallel_id; } void set_parallel_id(uint64_t parallel_id) { parallel_id_ = parallel_id; }
bool IsLossNode() const { return chain_node()->IsLossNode(); } bool IsLossNode() const { return chain_node()->IsLossNode(); }
bool IsFaker() const { return chain_node()->IsFaker(); } bool IsFaker() const { return chain_node()->IsFaker(); }
std::string VisualStr() const override;
void DataFwBuildExecAndProducedRegsts(TaskGraph*); // Build Exec and Set Produced Regsts
void MdUpdtFwBuildExecAndProducedRegsts(TaskGraph*); void DataFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void MdLoadFwBuildExecAndProducedRegsts(TaskGraph*); void DataFwInferShape4LbnInProducedRegsts(TaskGraph*);
void MdSaveFwBuildExecAndProducedRegsts(TaskGraph*);
std::string VisualStr() const override { void MdUpdtFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
std::stringstream ss; void MdUpdtFwInferShape4LbnInProducedRegsts(TaskGraph*);
ss << TaskNode::VisualStr()
<< "Compute_" << node_id_str() << ":" void MdLoadFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
<< stage_node()->machine_id_str() << ":" void MdLoadFwInferShape4LbnInProducedRegsts(TaskGraph*);
<< thrd_loc_id_str() << "\\n"
<< chain_node()->VisualStr(); void MdSaveFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
return ss.str(); void MdSaveFwInferShape4LbnInProducedRegsts(TaskGraph*);
}
protected: protected:
virtual void InitWithFwNode(TaskNode* fw_node) override { virtual void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node); TaskNode::InitWithFwNode(fw_node);
parallel_id_ = of_dynamic_cast<CompTaskNode*> (fw_node)->parallel_id_;
} }
private: private:
using Lbn2NodeBnMap = using Lbn2NodeBnMap =
HashMap<std::string, std::pair<ExecNode*, std::string>>; HashMap<std::string, std::pair<ExecNode*, std::string>>;
void FwBuildExecAndProducedRegsts(TaskGraph*) override; void FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override {
(this->*(gph->Func4FwBuildExecAndEnrollLbn2Regsts()))(gph);
}
void FwInferShape4LbnInProducedRegsts(TaskGraph* gph) override {
(this->*(gph->Func4FwInferShape4LbnInProducedRegsts()))(gph);
}
void FwBuildFromUserOps( void FwBuildFromUserOps(
Lbn2NodeBnMap* lbn2producer, Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer); Lbn2NodeBnMap* extern_in_lbn2consumer);
void FwSetDataRegstDesc( void FwSetExecNodeFromInRegst(
const Lbn2NodeBnMap& lbn2producer,
const Lbn2NodeBnMap& extern_in_lbn2consumer); const Lbn2NodeBnMap& extern_in_lbn2consumer);
void FwSetModelTmpRegstDesc(); void FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer);
void FwEnrollLbn2ActivationRegst();
void FwEnrollLbn2TmpRegsts();
void BpBuildExecAndProducedRegsts(TaskGraph*) override; void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecGraph( void BpBuildExecGraph(
const ExecGraph& fw_gph, const ExecGraph& fw_gph,
HashMap<const ExecNode*, ExecNode*>* fw_node2bp_node, HashMap<const ExecNode*, ExecNode*>* fw_node2bp_node,
HashMap<ExecEdge*, const ExecEdge*>* bp_edge2fw_edge); HashMap<ExecEdge*, const ExecEdge*>* bp_edge2fw_edge);
void BpSetDataDiffRegst( void BpEnrollLbn2ProducedRegst(
const HashMap<const ExecNode*, ExecNode*>& fw_node2bp_node, const HashMap<const ExecNode*, ExecNode*>& fw_node2bp_node,
const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge); const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge);
void BpSetModelDiffRegst();
uint64_t parallel_id_; uint64_t parallel_id_;
......
...@@ -4,52 +4,6 @@ ...@@ -4,52 +4,6 @@
namespace oneflow { namespace oneflow {
void CopyHDTaskNode::BuildExecAndProducedRegstsForCopy(TaskGraph* gph){
auto out_regst = of_make_unique<DisContigRegstDesc> ();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
OperatorConf op_conf;
op_conf.set_name("copy_" + NewUniqueId());
CopyOpConf* copy_conf = op_conf.mutable_copy_conf();
copy_conf->set_copy_type(
IsH2D() ? CopyOpConf::H2D : CopyOpConf::D2H);
for(std::string lbn : CopiedLbns()){
copy_conf->add_copied_lbns(lbn);
}
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
bool get_kalllbn = false;
if(copy_conf->copied_lbns_size() == 1
&& copy_conf->copied_lbns(0) == RegstDesc::kAllLbn){
out_regst->CopyLbn2ShapeMap(in_regst);
get_kalllbn = true;
}
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = OpMgr::Singleton().ConstructOp(op_conf);
for(std::string ibn : node->op()->input_bns()){
std::string lbn = node->op()->ibn2lbn(ibn);
Shape* shape_ptr = in_regst->GetMutShapePtr(lbn);
node->op()->SetShapePtr(ibn, shape_ptr);
node->BindBnInOpAndRegst(ibn, in_regst);
}
for(std::string obn : node->op()->output_bns()){
std::string lbn = node->op()->obn2lbn(obn);
Shape* shape_ptr = nullptr;
if(!get_kalllbn){
shape_ptr = out_regst->EnrollLbn(lbn);
} else {
shape_ptr = out_regst->GetMutShapePtr(lbn);
}
node->op()->SetShapePtr(obn, shape_ptr);
node->BindBnInOpAndRegst(obn, out_regst.get());
}
if(!get_kalllbn){
node->op()->InferShape4ObAndDtbFromIb();
}
mut_exec_gph().UpdateSourceAndSink();
EnrollProducedRegstDesc("copy", std::move(out_regst));
}
void CopyHDTaskNode::SetFwInCopy() { void CopyHDTaskNode::SetFwInCopy() {
CHECK(IsFwNode()); CHECK(IsFwNode());
is_fw_in_copy_ = true; is_fw_in_copy_ = true;
...@@ -60,21 +14,52 @@ void CopyHDTaskNode::SetFwOutCopy() { ...@@ -60,21 +14,52 @@ void CopyHDTaskNode::SetFwOutCopy() {
is_fw_in_copy_ = false; is_fw_in_copy_ = false;
} }
const std::vector<std::string>& CopyHDTaskNode::CopiedLbns() const {
return IsFwInCopy() ? chain_node()->input_lbns() : chain_node()->output_lbns();
}
void CopyHDTaskNode::InitWithFwNode(TaskNode* fw_node) { void CopyHDTaskNode::InitWithFwNode(TaskNode* fw_node) {
TaskNode::InitWithFwNode(fw_node); TaskNode::InitWithFwNode(fw_node);
is_fw_in_copy_ = of_dynamic_cast<CopyHDTaskNode*>(fw_node)->is_fw_in_copy_; is_fw_in_copy_ = of_dynamic_cast<CopyHDTaskNode*>(fw_node)->is_fw_in_copy_;
} }
void CopyHDTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) { void CopyHDTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
BuildExecAndProducedRegstsForCopy(gph); return CopyHdBuildExecAndEnrollLbn2Regsts();
} }
void CopyHDTaskNode::BpBuildExecAndProducedRegsts(TaskGraph* gph) { void CopyHDTaskNode::FwInferShape4LbnInProducedRegsts(TaskGraph*) {
BuildExecAndProducedRegstsForCopy(gph); return CopyHdInferShape4LbnInProducedRegsts();
}
void CopyHDTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
return CopyHdBuildExecAndEnrollLbn2Regsts();
}
void CopyHDTaskNode::BpInferShape4LbnInProducedRegsts(TaskGraph*) {
return CopyHdInferShape4LbnInProducedRegsts();
}
void CopyHDTaskNode::CopyHdBuildExecAndEnrollLbn2Regsts(){
auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyLbnFrom(in_regst);
OperatorConf op_conf;
op_conf.set_name("copy_hd_" + NewUniqueId());
CopyHdOpConf* copy_hd_conf = op_conf.mutable_copy_hd_conf();
copy_hd_conf->set_type(IsH2D() ? CopyHdOpConf::H2D : CopyHdOpConf::D2H);
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = OpMgr::Singleton().ConstructOp(op_conf);
node->BindBnInOpAndRegst(node->op()->SoleIbn(), in_regst);
node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst.get());
mut_exec_gph().UpdateSourceAndSink();
EnrollProducedRegstDesc("out", std::move(out_regst));
}
void CopyHDTaskNode::CopyHdInferShape4LbnInProducedRegsts() {
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
out_regst->CopyShapeFrom(in_regst);
} }
} // namespace oneflow } // namespace oneflow
...@@ -23,22 +23,24 @@ class CopyHDTaskNode final : public TaskNode { ...@@ -23,22 +23,24 @@ class CopyHDTaskNode final : public TaskNode {
void SetFwInCopy(); void SetFwInCopy();
void SetFwOutCopy(); void SetFwOutCopy();
const std::vector<std::string>& CopiedLbns() const;
std::string VisualStr() const override { std::string VisualStr() const override {
return TaskNode::VisualStr() + "CopyHD_" + node_id_str(); return TaskNode::VisualStr() + "CopyHD";
} }
private: private:
void InitWithFwNode(TaskNode* fw_node) override; void InitWithFwNode(TaskNode* fw_node) override;
void BuildExecAndProducedRegstsForCopy(TaskGraph*);
void FwBuildExecAndProducedRegsts(TaskGraph*) override;
void BpBuildExecAndProducedRegsts(TaskGraph*) override;
std::unique_ptr<TaskNode> CreateSameTypeNode() const override { std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CopyHDTaskNode> (); return of_make_unique<CopyHDTaskNode> ();
} }
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
void CopyHdBuildExecAndEnrollLbn2Regsts();
void CopyHdInferShape4LbnInProducedRegsts();
bool is_fw_in_copy_; bool is_fw_in_copy_;
}; };
......
...@@ -44,17 +44,25 @@ class ExecNode final : public Node<ExecNode, ExecEdge> { ...@@ -44,17 +44,25 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std::shared_ptr<const Operator>& mut_op() { return op_; } std::shared_ptr<const Operator>& mut_op() { return op_; }
void BindBnInOpAndRegst(const std::string& bn_in_op, RegstDesc* regst) { void BindBnInOpAndRegst(const std::string& bn_in_op, RegstDesc* regst) {
CHECK(bn_in_op2regst.emplace(bn_in_op, regst).second); CHECK(bn_in_op2regst_.emplace(bn_in_op, regst).second);
} }
RegstDesc* GetRegstFromBnInOp(const std::string& bn_in_op) { RegstDesc* GetRegstFromBnInOp(const std::string& bn_in_op) {
return bn_in_op2regst.at(bn_in_op); return bn_in_op2regst_.at(bn_in_op);
}
void BindBnInOpAndShapePtr(const std::string& bn_in_op, Shape* shape_ptr) {
CHECK(bn_in_op2shape_ptr_.emplace(bn_in_op, shape_ptr).second);
}
const HashMap<std::string, Shape*>& BnInOp2ShapePtr() const {
return bn_in_op2shape_ptr_;
} }
std::string VisualStr() const { TODO(); } std::string VisualStr() const { TODO(); }
private: private:
std::shared_ptr<const Operator> op_; std::shared_ptr<const Operator> op_;
HashMap<std::string, RegstDesc*> bn_in_op2regst; HashMap<std::string, RegstDesc*> bn_in_op2regst_;
HashMap<std::string, Shape*> bn_in_op2shape_ptr_;
}; };
......
...@@ -16,9 +16,15 @@ inline void TaskConnect(TaskNode* src_node, ...@@ -16,9 +16,15 @@ inline void TaskConnect(TaskNode* src_node,
} }
void TaskGraph::BuildExecAndProducedRegsts() { void TaskGraph::BuildExecAndEnrollLbn2Regsts() {
for (TaskNode& node : *this) { for (TaskNode& node : *this) {
node.BuildExecAndProducedRegsts(this); node.BuildExecAndEnrollLbn2Regsts(this);
}
}
void TaskGraph::InferShape4LbnInProducedRegsts() {
for (TaskNode& node : *this) {
node.InferShape4LbnInProducedRegsts(this);
} }
} }
......
...@@ -17,18 +17,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> { ...@@ -17,18 +17,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
OF_DISALLOW_COPY_AND_MOVE(TaskGraph); OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
virtual ~TaskGraph() = default; virtual ~TaskGraph() = default;
// Getters
const StageGraph* stage_gph() const { return stage_gph_.get(); } const StageGraph* stage_gph() const { return stage_gph_.get(); }
const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); } const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); }
const HashMap<CompTaskNode*, CompTaskNode*>& faker2mccoy() { const HashMap<CompTaskNode*, CompTaskNode*>& faker2mccoy() {
return faker2mccoy_; return faker2mccoy_;
} }
std::vector<CompTaskNode*> SortedCompTasksInChain(const ChainNode*) const;
void BuildExecAndProducedRegsts(); // Build Exec And Set Produced Registers
void BuildExecAndEnrollLbn2Regsts();
void InferShape4LbnInProducedRegsts();
typedef void (CompTaskNode::*CompTaskNodeMemFunc)(TaskGraph*); using CompTaskNodeMemFunc = void (CompTaskNode::*)(TaskGraph*);
virtual CompTaskNodeMemFunc Func4FwBuildExecAndProducedRegsts() const = 0; virtual CompTaskNodeMemFunc Func4FwBuildExecAndEnrollLbn2Regsts() const = 0;
virtual CompTaskNodeMemFunc Func4FwInferShape4LbnInProducedRegsts() const = 0;
std::vector<CompTaskNode*> SortedCompTasksInChain(const ChainNode*) const;
protected: protected:
TaskGraph() = default; TaskGraph() = default;
......
...@@ -40,11 +40,19 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() { ...@@ -40,11 +40,19 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
return bp_node; return bp_node;
} }
void TaskNode::BuildExecAndProducedRegsts(TaskGraph* gph) { void TaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
if (IsFwNode()) { if (IsFwNode()) {
FwBuildExecAndProducedRegsts(gph); FwBuildExecAndEnrollLbn2Regsts(gph);
} else { } else {
BpBuildExecAndProducedRegsts(gph); BpBuildExecAndEnrollLbn2Regsts(gph);
}
}
void TaskNode::InferShape4LbnInProducedRegsts(TaskGraph* gph) {
if (IsFwNode()) {
FwInferShape4LbnInProducedRegsts();
} else {
BpInferShape4LbnInProducedRegsts();
} }
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "task/task.pb.h" #include "task/task.pb.h"
#include "graph/stage_graph.h" #include "graph/stage_graph.h"
#include "graph/exec_graph.h" #include "graph/exec_graph.h"
#include "register/register_desc.h" #include "register/register_desc_manager.h"
namespace oneflow { namespace oneflow {
...@@ -39,7 +39,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> { ...@@ -39,7 +39,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std::unique_ptr<TaskNode> BuildAndConnectBpNode(); std::unique_ptr<TaskNode> BuildAndConnectBpNode();
// //
void BuildExecAndProducedRegsts(TaskGraph*); void BuildExecAndEnrollLbn2Regsts(TaskGraph*);
void InferShape4LbnInProducedRegsts(TaskGraph*);
RegstDesc* GetProducedRegstDesc(const std::string& regst_desc_name); RegstDesc* GetProducedRegstDesc(const std::string& regst_desc_name);
// //
...@@ -67,8 +68,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> { ...@@ -67,8 +68,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
void EnrollProducedRegstDesc(const std::string& regst_desc_name, void EnrollProducedRegstDesc(const std::string& regst_desc_name,
std::unique_ptr<RegstDesc>&& regst_desc); std::unique_ptr<RegstDesc>&& regst_desc);
virtual void FwBuildExecAndProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); } virtual void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void BpBuildExecAndProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); } virtual void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void FwInferShape4LbnInProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void BpInferShape4LbnInProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); }
private: private:
// In task_gph level // In task_gph level
......
...@@ -13,7 +13,7 @@ class CommNetOp final : public SysOperator { ...@@ -13,7 +13,7 @@ class CommNetOp final : public SysOperator {
~CommNetOp() = default; ~CommNetOp() = default;
void InitFromOpConf(const OperatorConf& op_conf) override; void InitFromOpConf(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); } void InferShape4ObAndDtbFromIb() const override { UNEXPECTED_RUN(); }
std::string GetValueFromPbOpConf(const std::string& k) const override; std::string GetValueFromPbOpConf(const std::string& k) const override;
std::string normal_ibn2lbn(const std::string& input_bn) const override; std::string normal_ibn2lbn(const std::string& input_bn) const override;
......
#include "operator/copy_hd_op.h"
#include "operator/operator_manager.h"
namespace oneflow {
void CopyHdOp::InitFromOpConf(const OperatorConf& op_conf) {
CHECK(op_conf.has_copy_hd_conf());
mut_op_conf() = op_conf;
EnrollInputBn("in");
EnrollOutputBn("out");
}
std::string CopyHdOp::GetValueFromPbOpConf(const std::string& k) const {
return GetValueFromPbMessage(op_conf().copy_hd_conf(), k);
}
REGISTER_OP(OperatorConf::kCopyConf, CopyHdOp);
} // namespace oneflow
#ifndef ONEFLOW_OPERATOR_COPY_OP_H_ #ifndef ONEFLOW_OPERATOR_COPY_HD_OP_H_
#define ONEFLOW_OPERATOR_COPY_OP_H_ #define ONEFLOW_OPERATOR_COPY_HD_OP_H_
#include "operator/operator.h" #include "operator/operator.h"
namespace oneflow { namespace oneflow {
class CopyOp final : public SysOperator { class CopyHdOp final : public SysOperator {
public: public:
OF_DISALLOW_COPY_AND_MOVE(CopyOp); OF_DISALLOW_COPY_AND_MOVE(CopyHdOp);
CopyOp() = default; CopyHdOp() = default;
~CopyOp() = default; ~CopyHdOp() = default;
void InitFromOpConf(const OperatorConf& op_conf) override; void InitFromOpConf(const OperatorConf& op_conf) override;
void InitFromOperatorProto(const OperatorProto& operatorproto) override;
OperatorProto ToOperatorProto() override;
void InferShape4ObAndDtbFromIb() const override;
std::string GetValueFromPbOpConf(const std::string& k) const override; std::string GetValueFromPbOpConf(const std::string& k) const override;
void InferShape4ObAndDtbFromIb() const override { UNEXPECTED_RUN(); }
std::string normal_ibn2lbn(const std::string& input_bn) const override { std::string normal_ibn2lbn(const std::string& input_bn) const override {
return ibn2lbn_.at(input_bn); return RegstDesc::kAllLbn;
} }
std::string obn2lbn(const std::string& output_bn) const override { std::string obn2lbn(const std::string& output_bn) const override {
return obn2lbn_.at(output_bn); return RegstDesc::kAllLbn;
} }
private: private:
HashMap<std::string, std::string> ibn2lbn_;
HashMap<std::string, std::string> obn2lbn_;
}; };
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_OPERATOR_COPY_OP_H_ #endif // ONEFLOW_OPERATOR_COPY_HD_OP_H_
#include "operator/copy_op.h"
#include "operator/operator_manager.h"
namespace oneflow {
void CopyOp::InitFromOpConf(const OperatorConf& op_conf) {
CHECK(op_conf.has_copy_conf());
mut_op_conf() = op_conf;
for (int64_t i = 0; i < op_conf.copy_conf().copied_lbns_size(); ++i) {
std::string ibn = "in_" + std::to_string(i);
EnrollInputBn(ibn);
CHECK(ibn2lbn_.emplace(ibn, op_conf.copy_conf().copied_lbns(i)).second);
std::string obn = "out_" + std::to_string(i);
EnrollOutputBn(obn);
CHECK(obn2lbn_.emplace(obn, op_conf.copy_conf().copied_lbns(i)).second);
}
}
std::string CopyOp::GetValueFromPbOpConf(const std::string& k) const {
return GetValueFromPbMessage(op_conf().copy_conf(), k);
}
void CopyOp::InitFromOperatorProto(const OperatorProto& operatorproto) {
CHECK(operatorproto.has_copy_op());
Operator::InitFromOperatorProto(operatorproto);
ibn2lbn_ = PbMap2HashMap(operatorproto.copy_op().ibn2lbn());
obn2lbn_ = PbMap2HashMap(operatorproto.copy_op().obn2lbn());
}
OperatorProto CopyOp::ToOperatorProto() {
OperatorProto operatorproto = Operator::ToOperatorProto();
CopyOpProto copyopproto;
*(copyopproto.mutable_ibn2lbn()) = HashMap2PbMap(ibn2lbn_);
*(copyopproto.mutable_obn2lbn()) = HashMap2PbMap(obn2lbn_);
*(operatorproto.mutable_copy_op()) = copyopproto;
return operatorproto;
}
void CopyOp::InferShape4ObAndDtbFromIb() const {
CHECK_EQ(output_bns().size(), input_bns().size());
for(size_t i = 0;i < output_bns().size();++ i){
std::string obn = output_bns().at(i);
std::string ibn = input_bns().at(i);
*GetShapePtr(obn) = *GetShapePtr(ibn);
}
}
REGISTER_OP(OperatorConf::kCopyConf, CopyOp);
} // namespace oneflow
...@@ -170,13 +170,12 @@ message CommNetOpConf { ...@@ -170,13 +170,12 @@ message CommNetOpConf {
CommNetType comm_net_type = 1; CommNetType comm_net_type = 1;
} }
message CopyOpConf { message CopyHdOpConf {
enum CopyType { enum CopyHdType {
H2D = 0; H2D = 0;
D2H = 1; D2H = 1;
} }
CopyType copy_type = 1; CopyHdType type = 1;
repeated string copied_lbns = 2;
} }
...@@ -186,11 +185,14 @@ message CloneOpConf { ...@@ -186,11 +185,14 @@ message CloneOpConf {
} }
message BoxConcatConf { message BoxConcatConf {
int32 axis = 1; enum ConcatType {
kData = 0;
kModel = 1;
};
ConcatType type = 1;
} }
message BoxSplitConf { message BoxDataSplitConf {
int32 axis = 1;
} }
message BoxCloneConf { message BoxCloneConf {
...@@ -202,7 +204,7 @@ message BoxingOpConf { ...@@ -202,7 +204,7 @@ message BoxingOpConf {
uint32 out_num = 3; uint32 out_num = 3;
BoxConcatConf concat_box = 4; BoxConcatConf concat_box = 4;
oneof out_box { oneof out_box {
BoxSplitConf split_box = 5; BoxDataSplitConf data_split_box = 5;
BoxCloneConf clone_box = 6; BoxCloneConf clone_box = 6;
} }
} }
...@@ -226,7 +228,7 @@ message OperatorConf { ...@@ -226,7 +228,7 @@ message OperatorConf {
ReluOpConf relu_conf = 104; ReluOpConf relu_conf = 104;
SoftmaxOpConf softmax_conf = 105; SoftmaxOpConf softmax_conf = 105;
MultinomialLogisticLossOpConf multinomial_logistic_loss_conf = 106; MultinomialLogisticLossOpConf multinomial_logistic_loss_conf = 106;
CopyOpConf copy_conf = 107; CopyHdOpConf copy_hd_conf = 107;
CloneOpConf clone_conf = 108; CloneOpConf clone_conf = 108;
BoxingOpConf boxing_conf = 109; BoxingOpConf boxing_conf = 109;
ModelUpdateOpConf model_update_conf = 110; ModelUpdateOpConf model_update_conf = 110;
......
syntax = "proto3"; syntax = "proto3";
package oneflow; package oneflow;
import "operator/op_conf.proto";
message CopyOpProto { import "operator/op_conf.proto";
map<string, string> ibn2lbn = 1;
map<string, string> obn2lbn = 2;
}
message OperatorProto { message OperatorProto {
OperatorConf user_conf = 1; OperatorConf user_conf = 1;
...@@ -22,6 +18,5 @@ message OperatorProto { ...@@ -22,6 +18,5 @@ message OperatorProto {
repeated string model_tmp_bns = 11; repeated string model_tmp_bns = 11;
oneof specified_op_proto { oneof specified_op_proto {
CopyOpProto copy_op = 100;
} }
} }
...@@ -7,14 +7,22 @@ RegstDesc::RegstDesc() { ...@@ -7,14 +7,22 @@ RegstDesc::RegstDesc() {
producer_ = nullptr; producer_ = nullptr;
} }
void RegstDesc::CopyLbn2ShapeMap(const RegstDesc* rhs) { void RegstDesc::CopyLbnFrom(const RegstDesc* rhs) {
lbn2shape_.clear();
for (const auto& pair : rhs->lbn2shape_) { for (const auto& pair : rhs->lbn2shape_) {
const std::string& lbn = pair.first; const std::string& lbn = pair.first;
auto shape = of_make_unique<Shape> (*(pair.second)); auto shape = of_make_unique<Shape> ();
CHECK(lbn2shape_.emplace(lbn, std::move(shape)).second); CHECK(lbn2shape_.emplace(lbn, std::move(shape)).second);
} }
} }
void RegstDesc::CopyShapeFrom(const RegstDesc* rhs) {
for (const auto& pair : lbn2shape_) {
const std::string& lbn = pair.first;
*(lbn2shape_.at(lbn)) = rhs->GetShape(lbn);
}
}
Shape* RegstDesc::EnrollLbn(const std::string& lbn) { Shape* RegstDesc::EnrollLbn(const std::string& lbn) {
Shape* raw_ptr = new Shape; Shape* raw_ptr = new Shape;
std::unique_ptr<Shape> uptr(raw_ptr); std::unique_ptr<Shape> uptr(raw_ptr);
......
...@@ -11,11 +11,11 @@ namespace oneflow { ...@@ -11,11 +11,11 @@ namespace oneflow {
class TaskNode; class TaskNode;
class RegstDesc { class RegstDesc final {
public: public:
OF_DISALLOW_COPY_AND_MOVE(RegstDesc); OF_DISALLOW_COPY_AND_MOVE(RegstDesc);
RegstDesc(); RegstDesc();
virtual ~RegstDesc() = default; ~RegstDesc() = default;
// regst_desc_id // regst_desc_id
uint64_t regst_desc_id() const { return regst_desc_id_; } uint64_t regst_desc_id() const { return regst_desc_id_; }
...@@ -25,7 +25,8 @@ class RegstDesc { ...@@ -25,7 +25,8 @@ class RegstDesc {
void SetProducer(const TaskNode* task_node) { producer_ = task_node; } void SetProducer(const TaskNode* task_node) { producer_ = task_node; }
// Lbn and Shape // Lbn and Shape
void CopyLbn2ShapeMap(const RegstDesc*); void CopyLbnFrom(const RegstDesc*);
void CopyShapeFrom(const RegstDesc*);
Shape* EnrollLbn(const std::string& lbn); Shape* EnrollLbn(const std::string& lbn);
const Shape& GetShape(const std::string& lbn); const Shape& GetShape(const std::string& lbn);
Shape* GetMutShapePtr(const std::string& lbn); Shape* GetMutShapePtr(const std::string& lbn);
...@@ -40,24 +41,6 @@ class RegstDesc { ...@@ -40,24 +41,6 @@ class RegstDesc {
}; };
class ContigRegstDesc final : public RegstDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(ContigRegstDesc);
ContigRegstDesc() = default;
~ContigRegstDesc() = default;
private:
};
class DisContigRegstDesc final : public RegstDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(DisContigRegstDesc);
DisContigRegstDesc() = default;
~DisContigRegstDesc() = default;
};
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_H_ #endif // ONEFLOW_REGISTER_REGISTER_DESC_H_
#ifndef ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#define ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#include "register/register_desc.h"
namespace oneflow {
class RegstDescMgr final {
public:
OF_DISALLOW_COPY_AND_MOVE(RegstDescMgr);
RegstDescMgr() = default;
~RegstDescMgr() = default;
std::unique_ptr<RegstDesc> CreateRegisterDesc() {
return of_make_unique<RegstDesc> ();
}
private:
};
} // namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册