提交 3589c342 编写于 作者: W willzhang4a58

refine task_node::build_exec

上级 b55817ac
......@@ -8,28 +8,28 @@ namespace oneflow {
namespace {
void FwCompleteBoxOpConfDataData(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(0);
conf->mutable_split_box()->set_axis(0);
conf->mutable_concat_box()->set_type(BoxConcatConf::kData);
conf->mutable_data_split_box();
}
void FwCompleteBoxOpConfDataModel(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(0);
conf->mutable_concat_box()->set_type(BoxConcatConf::kData);
conf->mutable_clone_box();
}
void FwCompleteBoxOpConfModelData(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(1);
conf->mutable_split_box()->set_axis(0);
conf->mutable_concat_box()->set_type(BoxConcatConf::kModel);
conf->mutable_data_split_box();
}
void FwCompleteBoxOpConfModelModel(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(1);
conf->mutable_concat_box()->set_type(BoxConcatConf::kModel);
conf->mutable_clone_box();
}
} // namespace
void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
void BoxingTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
EnrollAllRegstAndBindRelatedEdge();
FwVirtualBuild();
}
......@@ -37,11 +37,11 @@ void BoxingTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
void BoxingTaskNode::EnrollAllRegstAndBindRelatedEdge() {
for (TaskEdge* edge : out_edges()) {
std::string name = "boxing_out_" + edge->edge_id_str();
auto regst_desc = of_make_unique<DisContigRegstDesc> ();
auto regst_desc = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(regst_desc.get(), edge);
EnrollProducedRegstDesc(name, std::move(regst_desc));
}
auto regst_desc = of_make_unique<DisContigRegstDesc> ();
auto regst_desc = RegstDescMgr::Singleton().CreateRegisterDesc();
EnrollProducedRegstDesc("middle", std::move(regst_desc));
}
......@@ -133,7 +133,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const std::string& ibn = node->op()->input_bns().at(i);
std::string lbn = node->op()->ibn2lbn(ibn);
Shape* ptr = in_regst->GetMutShapePtr(lbn);
node->op()->SetShapePtr(ibn, ptr);
node->BindBnInOpAndShapePtr(ibn, ptr);
node->BindBnInOpAndRegst(ibn, in_regst);
}
// obn
......@@ -142,16 +142,22 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const std::string& obn = node->op()->output_bns().at(i);
std::string lbn = node->op()->obn2lbn(obn);
Shape* ptr = out_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(obn, ptr);
node->BindBnInOpAndShapePtr(obn, ptr);
node->BindBnInOpAndRegst(obn, out_regst);
}
// dtbn
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn);
Shape* ptr = middle_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(dtbn, ptr);
node->BindBnInOpAndShapePtr(dtbn, ptr);
node->BindBnInOpAndRegst(dtbn, middle_regst);
}
node->op()->InferShape4ObAndDtbFromIb();
}
}
void BoxingTaskNode::FwInferShape4LbnInProducedRegsts(TaskGraph*) {
for (const auto& exec_node : exec_gph().nodes()) {
exec_node->op()->InferShape4ObAndDtbFromIb(exec_node->BnInOp2ShapePtr());
}
}
......@@ -165,10 +171,9 @@ inline RegstDesc* GetBpRegstFromFwRegst(RegstDesc* fw_regst) {
}
void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
EnrollAllRegstAndBindRelatedEdge();
const ExecGraph& fw_exec_gph = GetFwNode()->exec_gph();
HashMap<const ExecNode*, ExecNode*> fw_node2bp_node;
for (const std::unique_ptr<ExecNode>& fw_node: fw_exec_gph.nodes()) {
ExecNode* bp_node = mut_exec_gph().NewNode();
bp_node->mut_op() = fw_node->op();
......@@ -178,8 +183,7 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
std::string lbn = fw_node->op()->ibn2lbn(ibn);
RegstDesc* in_regst = fw_node->GetRegstFromBnInOp(ibn);
RegstDesc* in_diff_regst = GetBpRegstFromFwRegst(in_regst);
Shape* in_diff_shape_ptr = in_diff_regst->EnrollLbn(lbn);
*in_diff_shape_ptr = in_regst->GetShape(lbn);
in_diff_regst->EnrollLbn(lbn);
bp_node->BindBnInOpAndRegst(idbn, in_diff_regst);
}
// out_diff
......@@ -195,12 +199,22 @@ void BoxingTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
std::string lbn = fw_node->op()->dtbn2lbn(dtbn);
RegstDesc* fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
RegstDesc* bp_middle_regst = GetProducedRegstDesc("middle");
Shape* ptr = bp_middle_regst->EnrollLbn(lbn);
*ptr = fw_middle_regst->GetShape(lbn);
bp_middle_regst->EnrollLbn(lbn);
bp_node->BindBnInOpAndRegst(dtbn, bp_middle_regst);
}
}
mut_exec_gph().UpdateSourceAndSink();
}
void BoxingTaskNode::BpInferShape4LbnInProducedRegsts(TaskGraph*) {
for (TaskEdge* fw_in_edge : GetFwNode()->in_edges()) {
RegstDesc* in_regst = GetRelatedRegst(fw_in_edge);
RegstDesc* in_diff_regst = GetBpRegstFromFwRegst(in_regst);
in_diff_regst->CopyShapeFrom(in_regst);
}
RegstDesc* fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
RegstDesc* bp_middle_regst = GetProducedRegstDesc("middle");
bp_middle_regst->CopyShapeFrom(fw_middle_regst);
}
} // namespace oneflow
......@@ -12,7 +12,7 @@ class BoxingTaskNode : public TaskNode {
virtual ~BoxingTaskNode() = default;
std::string VisualStr() const override {
return TaskNode::VisualStr() + "Boxing_" + node_id_str();
return TaskNode::VisualStr() + "Boxing";
}
protected:
......@@ -39,8 +39,10 @@ class BoxingTaskNode : public TaskNode {
virtual void FwVirtualBuild() = 0;
private:
void FwBuildExecAndProducedRegsts(TaskGraph*) override;
void BpBuildExecAndProducedRegsts(TaskGraph*) override;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
void EnrollAllRegstAndBindRelatedEdge();
......
......@@ -4,11 +4,11 @@
namespace oneflow {
void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){
auto out_regst = of_make_unique<DisContigRegstDesc> ();
void CommNetTaskNode::CommNetBuildExecAndEnrollLbn2Regsts() {
auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyLbn2ShapeMap(in_regst);
out_regst->CopyLbnFrom(in_regst);
OperatorConf op_conf;
op_conf.set_name("comm_net_" + NewUniqueId());
......@@ -22,15 +22,29 @@ void CommNetTaskNode::BuildExecAndProducedRegstsForNetCopy(TaskGraph* gph){
node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst.get());
mut_exec_gph().UpdateSourceAndSink();
EnrollProducedRegstDesc("comm_net", std::move(out_regst));
EnrollProducedRegstDesc("out", std::move(out_regst));
}
void CommNetTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
BuildExecAndProducedRegstsForNetCopy(gph);
void CommNetTaskNode::CommNetInferShape4LbnInProducedRegsts() {
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
out_regst->CopyShapeFrom(in_regst);
}
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override {
return CommNetBuildExecAndEnrollLbn2Regsts();
}
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override {
return CommNetInferShape4LbnInProducedRegsts();
}
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override {
return CommNetBuildExecAndEnrollLbn2Regsts();
}
void CommNetTaskNode::BpBuildExecAndProducedRegsts(TaskGraph* gph) {
BuildExecAndProducedRegstsForNetCopy(gph);
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override {
return CommNetInferShape4LbnInProducedRegsts();
}
} // namespace oneflow
......@@ -29,13 +29,17 @@ class CommNetTaskNode final : public TaskNode {
}
std::string VisualStr() const override {
return TaskNode::VisualStr() + "CommNet_" + node_id_str();
return TaskNode::VisualStr() + "CommNet";
}
private:
void BuildExecAndProducedRegstsForNetCopy(TaskGraph*);
void FwBuildExecAndProducedRegsts(TaskGraph*) override;
void BpBuildExecAndProducedRegsts(TaskGraph*) override;
void CommNetBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void CommNetInferShape4LbnInProducedRegsts();
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CommNetTaskNode> ();
......
......@@ -5,27 +5,52 @@
namespace oneflow {
void CompTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
(this->*(gph->Func4FwBuildExecAndProducedRegsts()))(gph);
std::string CompTaskNode::VisualStr() const override {
std::stringstream ss;
ss << TaskNode::VisualStr()
<< "Compute" << ":"
<< stage_node()->machine_id_str() << ":"
<< thrd_loc_id_str() << "\\n"
<< chain_node()->VisualStr();
return ss.str();
}
void CompTaskNode::DataFwBuildExecAndProducedRegsts(TaskGraph*) {
void CompTaskNode::DataFwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
Lbn2NodeBnMap lbn2producer;
Lbn2NodeBnMap extern_in_lbn2consumer;
FwBuildFromUserOps(&lbn2producer, &extern_in_lbn2consumer);
mut_exec_gph().UpdateSourceAndSink();
// data regst
auto data_regst = of_make_unique<DisContigRegstDesc> ();
BindProducedRegstAndOutEdge(data_regst.get(), SoleOutEdge());
EnrollProducedRegstDesc("data", std::move(data_regst));
FwSetDataRegstDesc(lbn2producer, extern_in_lbn2consumer);
// model_tmp regst
auto model_tmp_regst = of_make_unique<DisContigRegstDesc> ();
// produced regsts
auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto activation_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto data_tmp_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto model_tmp_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
// Bind Out Edge
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
// EnrollProducedRegstDesc
EnrollProducedRegstDesc("out", std::move(out_regst));
EnrollProducedRegstDesc("activation", std::move(activation_regst));
EnrollProducedRegstDesc("data_tmp", std::move(data_tmp_regst));
EnrollProducedRegstDesc("model_tmp", std::move(model_tmp_regst));
FwSetModelTmpRegstDesc();
// Enroll Lbn
FwSetExecNodeFromInRegst(extern_in_lbn2consumer);
FwEnrollLbn2OutRegst(lbn2producer);
FwEnrollLbn2ActivationRegst();
FwEnrollLbn2TmpRegsts();
}
void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) {
void CompTaskNode::DataFwInferShape4LbnInProducedRegsts() {
for (const ExecNode& node : exec_gph()) {
node.op()->InferShape4ObAndDtbFromIb(node.BnInOp2ShapePtr());
node.op()->InferShape4ModelTmpBlob(node.BnInOp2ShapePtr(),
chain_node()->parallel_desc()->policy(),
parallel_id());
}
}
void CompTaskNode::MdUpdtFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
TODO();
/*
if (IsFaker()) {
CompTaskNode* mccoy = gph->faker2mccoy().at(this);
RegstDesc* regst = mccoy->GetProducedRegstDesc("model_diff");
......@@ -39,9 +64,16 @@ void CompTaskNode::MdUpdtFwBuildExecAndProducedRegsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
// PostProcessing in ModelUpdateTaskGraph will complete the work
// which should be implemented in this function
*/
}
void CompTaskNode::MdUpdtFwInferShape4LbnInProducedRegsts(TaskGraph* gph) {
TODO();
}
void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) {
void CompTaskNode::MdLoadFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
TODO();
/*
if (IsFaker()) {
CompTaskNode* update_task = gph->faker2mccoy().at(this);
ExecNode* exec_node = update_task->exec_gph().SoleNode();
......@@ -58,9 +90,16 @@ void CompTaskNode::MdLoadFwBuildExecAndProducedRegsts(TaskGraph* gph) {
exec_node->op()->SetShapePtr(exec_node->op()->SoleObn(), shape_ptr);
exec_node->op()->InferShape4ObAndDtbFromIb();
EnrollProducedRegstDesc("model_regst", std::move(model_regst));
*/
}
void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) {
void CompTaskNode::MdLoadFwInferShape4LbnInProducedRegsts(TaskGraph* gph) {
TODO();
}
void CompTaskNode::MdSaveFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
TODO();
/*
if (IsFaker()) {
CompTaskNode* update_task = gph->faker2mccoy().at(this);
RegstDesc* model_regst = update_task->GetProducedRegstDesc("model");
......@@ -72,6 +111,11 @@ void CompTaskNode::MdSaveFwBuildExecAndProducedRegsts(TaskGraph* gph) {
mut_exec_gph().UpdateSourceAndSink();
const std::string& ibn = exec_node->op()->SoleIbn();
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
*/
}
void CompTaskNode::MdSaveFwInferShape4LbnInProducedRegsts(TaskGraph* gph) {
TODO();
}
void CompTaskNode::FwBuildFromUserOps(
......@@ -103,80 +147,97 @@ void CompTaskNode::FwBuildFromUserOps(
}
}
void CompTaskNode::FwSetDataRegstDesc(
const Lbn2NodeBnMap& lbn2producer,
void CompTaskNode::FwSetExecNodeFromInRegst(
const Lbn2NodeBnMap& extern_in_lbn2consumer) {
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
// blob on exec_edge
for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) {
Shape* ptr = out_regst->EnrollLbn(edge->lbn());
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), out_regst);
edge->src_node()->op()->SetShapePtr(edge->src_bn(), ptr);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), out_regst);
edge->dst_node()->op()->SetShapePtr(edge->dst_bn(), ptr);
}
// extern in blobs
for (const auto& pair : extern_in_lbn2consumer) {
const std::string& lbn = pair.first;
Shape* ptr = in_regst->GetMutShapePtr(lbn);
ExecNode* node = pair.second.first;
const std::string& ibn = pair.second.second;
node->op()->SetShapePtr(ibn, ptr);
node->BindBnInOpAndShapePtr(ibn, ptr);
node->BindBnInOpAndRegst(ibn, in_regst);
}
// extern out blobs
}
void CompTaskNode::FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer) {
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
for (const std::string& lbn : chain_node()->output_lbns()) {
const std::pair<ExecNode*, std::string>& producer = lbn2producer.at(lbn);
ExecNode* node = producer.first;
const std::string& obn = producer.second;
Shape* ptr = out_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(obn, ptr);
node->BindBnInOpAndShapePtr(obn, ptr);
node->BindBnInOpAndRegst(obn, out_regst);
}
// data tmp blobs
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn);
Shape* ptr = out_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(dtbn, ptr);
node->BindBnInOpAndRegst(dtbn, out_regst);
}
}
// Inference Shape
for (const ExecNode& node : exec_gph()) {
node.op()->InferShape4ObAndDtbFromIb();
}
void CompTaskNode::FwEnrollLbn2ActivationRegst() {
RegstDesc* activation_regst = GetProducedRegstDesc("activation");
for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) {
Shape* ptr = activation_regst->EnrollLbn(edge->lbn());
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_regst);
edge->src_node()->BindBnInOpAndShapePtr(edge->src_bn(), ptr);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_regst);
edge->dst_node()->BindBnInOpAndShapePtr(edge->dst_bn(), ptr);
}
}
void CompTaskNode::FwSetModelTmpRegstDesc() {
void CompTaskNode::FwEnrollLbn2TmpRegsts() {
RegstDesc* data_tmp_regst = GetProducedRegstDesc("data_tmp");
RegstDesc* model_tmp_regst = GetProducedRegstDesc("model_tmp");
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn);
Shape* ptr = data_tmp_regst->EnrollLbn(lbn);
node->BindBnInOpAndShapePtr(dtbn, ptr);
node->BindBnInOpAndRegst(dtbn, out_regst);
}
for (const std::string& mtbn : node->op()->model_tmp_bns()) {
std::string lbn = node->op()->mtbn2lbn(mtbn);
Shape* ptr = model_tmp_regst->EnrollLbn(lbn);
node->op()->SetShapePtr(mtbn, ptr);
node->BindBnInOpAndShapePtr(mtbn, ptr);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
}
node->op()->InferShape4ModelTmpBlob(chain_node()->parallel_desc()->policy(),
parallel_id());
}
}
void CompTaskNode::BpBuildExecAndProducedRegsts(TaskGraph*) {
void CompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
const ExecGraph& fw_gph = GetFwNode()->exec_gph();
HashMap<const ExecNode*, ExecNode*> fw_node2bp_node;
HashMap<ExecEdge*, const ExecEdge*> bp_edge2fw_edge;
BpBuildExecGraph(fw_gph, &fw_node2bp_node, &bp_edge2fw_edge);
//
auto data_diff_regst = of_make_unique<DisContigRegstDesc> ();
BindProducedRegstAndOutEdge(data_diff_regst.get(), SoleOutEdge());
EnrollProducedRegstDesc("data_diff", std::move(data_diff_regst));
BpSetDataDiffRegst(fw_node2bp_node, bp_edge2fw_edge);
//
auto model_diff_regst = of_make_unique<ContigRegstDesc> ();
// Produced registers
auto in_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto model_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto activation_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
// Bind out edge
BindProducedRegstAndOutEdge(in_diff_regst.get(), SoleOutEdge());
// Enroll registers
EnrollProducedRegstDesc("in_diff", std::move(in_diff_regst));
EnrollProducedRegstDesc("model_diff", std::move(model_diff_regst));
BpSetModelDiffRegst();
EnrollProducedRegstDesc("activation_diff", std::move(activation_diff_regst));
// Enroll Lbn
BpEnrollLbn2ProducedRegst(fw_node2bp_node, bp_edge2fw_edge);
}
void CompTaskNode::BpInferShape4LbnInProducedRegsts(TaskGraph*) {
// in_diff_regst
RegstDesc* in_diff_regst = GetRelatedRegst(SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
in_diff_regst->CopyShapeFrom(in_regst);
// model_diff_regst
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
for (const std::unique_ptr<ExecNode>& cur_node : exec_gph().nodes()) {
cur_node->op()->InferShape4ModelDiffBlob(
cur_node->BnInOp2ShapePtr(),
chain_node()->parallel_desc()->policy(),
parallel_id());
}
// activation_diff_regst
RegstDesc* activation_diff_regst = GetProducedRegstDesc("activation_diff");
RegstDesc* activation_regst = GetFwNode()->GetProducedRegstDesc("activation");
activation_diff_regst->CopyShapeFrom(activation_regst);
}
void CompTaskNode::BpBuildExecGraph(
......@@ -200,21 +261,23 @@ void CompTaskNode::BpBuildExecGraph(
}
}
void CompTaskNode::BpSetDataDiffRegst(
void CompTaskNode::BpEnrollLbn2ProducedRegst(
const HashMap<const ExecNode*, ExecNode*>& fw_node2bp_node,
const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge) {
// Regsts
RegstDesc* in_diff_regst = GetRelatedRegst(SoleOutEdge());
RegstDesc* out_diff_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(GetFwNode()->SoleOutEdge());
RegstDesc* activation_regst = GetFwNode()->GetProducedRegstDesc("activation");
RegstDesc* data_tmp_regst = GetFwNode()->GetProducedRegstDesc("data_tmp");
RegstDesc* model_tmp_regst = GetFwNode()->GetProducedRegstDesc("model_tmp");
RegstDesc* activation_diff_regst = GetProducedRegstDesc("activation_diff");
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
// blobs on edge
for (const std::unique_ptr<ExecEdge>& edge : exec_gph().edges()) {
Shape* ptr = in_diff_regst->EnrollLbn(edge->lbn());
*ptr = out_regst->GetShape(edge->lbn());
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), in_diff_regst);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), in_diff_regst);
activation_diff_regst->EnrollLbn(edge->lbn());
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_diff_regst);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_diff_regst);
}
// extern out_diff blobs
for (const std::unique_ptr<ExecNode>& bp_node : exec_gph().nodes()) {
......@@ -237,35 +300,27 @@ void CompTaskNode::BpSetDataDiffRegst(
for (const std::string& idbn : bp_node->op()->input_diff_bns()) {
if (found_bns.find(idbn) != found_bns.end()) { continue; }
std::string lbn = bp_node->op()->idbn2lbn(idbn);
Shape* ptr = in_diff_regst->EnrollLbn(lbn);
*ptr = in_regst->GetShape(lbn);
in_diff_regst->EnrollLbn(lbn);
bp_node->BindBnInOpAndRegst(idbn, in_diff_regst);
bp_node->BindBnInOpAndRegst(GenUnDiffBn(idbn), in_regst);
}
}
// tmp blobs
// tmp blobs and model_diff blobs
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
std::string lbn = node->op()->dtbn2lbn(dtbn);
node->BindBnInOpAndRegst(dtbn, out_regst);
node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
}
for (const std::string& mtbn : node->op()->model_tmp_bns()) {
std::string lbn = node->op()->mtbn2lbn(mtbn);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
}
}
}
void CompTaskNode::BpSetModelDiffRegst() {
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
for (const std::unique_ptr<ExecNode>& cur_node : exec_gph().nodes()) {
for (const std::string& mdbn : cur_node->op()->model_diff_bns()) {
std::string lbn = cur_node->op()->mdbn2lbn(mdbn);
Shape* ptr = model_diff_regst->EnrollLbn(lbn);
cur_node->op()->SetShapePtr(mdbn, ptr);
cur_node->BindBnInOpAndShapePtr(mdbn, ptr);
cur_node->BindBnInOpAndRegst(mdbn, model_diff_regst);
}
cur_node->op()->InferShape4ModelDiffBlob(chain_node()->parallel_desc()->policy(),
parallel_id());
}
}
......
......@@ -12,55 +12,60 @@ class CompTaskNode : public TaskNode {
CompTaskNode() = default;
virtual ~CompTaskNode() = default;
// Getters and Setters
uint64_t parallel_id() const { return parallel_id_; }
void set_parallel_id(uint64_t parallel_id) { parallel_id_ = parallel_id; }
bool IsLossNode() const { return chain_node()->IsLossNode(); }
bool IsFaker() const { return chain_node()->IsFaker(); }
std::string VisualStr() const override;
void DataFwBuildExecAndProducedRegsts(TaskGraph*);
void MdUpdtFwBuildExecAndProducedRegsts(TaskGraph*);
void MdLoadFwBuildExecAndProducedRegsts(TaskGraph*);
void MdSaveFwBuildExecAndProducedRegsts(TaskGraph*);
// Build Exec and Set Produced Regsts
void DataFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void DataFwInferShape4LbnInProducedRegsts(TaskGraph*);
std::string VisualStr() const override {
std::stringstream ss;
ss << TaskNode::VisualStr()
<< "Compute_" << node_id_str() << ":"
<< stage_node()->machine_id_str() << ":"
<< thrd_loc_id_str() << "\\n"
<< chain_node()->VisualStr();
return ss.str();
}
void MdUpdtFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void MdUpdtFwInferShape4LbnInProducedRegsts(TaskGraph*);
void MdLoadFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void MdLoadFwInferShape4LbnInProducedRegsts(TaskGraph*);
void MdSaveFwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void MdSaveFwInferShape4LbnInProducedRegsts(TaskGraph*);
protected:
virtual void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
parallel_id_ = of_dynamic_cast<CompTaskNode*> (fw_node)->parallel_id_;
}
private:
using Lbn2NodeBnMap =
HashMap<std::string, std::pair<ExecNode*, std::string>>;
void FwBuildExecAndProducedRegsts(TaskGraph*) override;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override {
(this->*(gph->Func4FwBuildExecAndEnrollLbn2Regsts()))(gph);
}
void FwInferShape4LbnInProducedRegsts(TaskGraph* gph) override {
(this->*(gph->Func4FwInferShape4LbnInProducedRegsts()))(gph);
}
void FwBuildFromUserOps(
Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer);
void FwSetDataRegstDesc(
const Lbn2NodeBnMap& lbn2producer,
void FwSetExecNodeFromInRegst(
const Lbn2NodeBnMap& extern_in_lbn2consumer);
void FwSetModelTmpRegstDesc();
void FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer);
void FwEnrollLbn2ActivationRegst();
void FwEnrollLbn2TmpRegsts();
void BpBuildExecAndProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecGraph(
const ExecGraph& fw_gph,
HashMap<const ExecNode*, ExecNode*>* fw_node2bp_node,
HashMap<ExecEdge*, const ExecEdge*>* bp_edge2fw_edge);
void BpSetDataDiffRegst(
void BpEnrollLbn2ProducedRegst(
const HashMap<const ExecNode*, ExecNode*>& fw_node2bp_node,
const HashMap<ExecEdge*, const ExecEdge*>& bp_edge2fw_edge);
void BpSetModelDiffRegst();
uint64_t parallel_id_;
......
......@@ -4,52 +4,6 @@
namespace oneflow {
void CopyHDTaskNode::BuildExecAndProducedRegstsForCopy(TaskGraph* gph){
auto out_regst = of_make_unique<DisContigRegstDesc> ();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
OperatorConf op_conf;
op_conf.set_name("copy_" + NewUniqueId());
CopyOpConf* copy_conf = op_conf.mutable_copy_conf();
copy_conf->set_copy_type(
IsH2D() ? CopyOpConf::H2D : CopyOpConf::D2H);
for(std::string lbn : CopiedLbns()){
copy_conf->add_copied_lbns(lbn);
}
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
bool get_kalllbn = false;
if(copy_conf->copied_lbns_size() == 1
&& copy_conf->copied_lbns(0) == RegstDesc::kAllLbn){
out_regst->CopyLbn2ShapeMap(in_regst);
get_kalllbn = true;
}
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = OpMgr::Singleton().ConstructOp(op_conf);
for(std::string ibn : node->op()->input_bns()){
std::string lbn = node->op()->ibn2lbn(ibn);
Shape* shape_ptr = in_regst->GetMutShapePtr(lbn);
node->op()->SetShapePtr(ibn, shape_ptr);
node->BindBnInOpAndRegst(ibn, in_regst);
}
for(std::string obn : node->op()->output_bns()){
std::string lbn = node->op()->obn2lbn(obn);
Shape* shape_ptr = nullptr;
if(!get_kalllbn){
shape_ptr = out_regst->EnrollLbn(lbn);
} else {
shape_ptr = out_regst->GetMutShapePtr(lbn);
}
node->op()->SetShapePtr(obn, shape_ptr);
node->BindBnInOpAndRegst(obn, out_regst.get());
}
if(!get_kalllbn){
node->op()->InferShape4ObAndDtbFromIb();
}
mut_exec_gph().UpdateSourceAndSink();
EnrollProducedRegstDesc("copy", std::move(out_regst));
}
void CopyHDTaskNode::SetFwInCopy() {
CHECK(IsFwNode());
is_fw_in_copy_ = true;
......@@ -60,21 +14,52 @@ void CopyHDTaskNode::SetFwOutCopy() {
is_fw_in_copy_ = false;
}
const std::vector<std::string>& CopyHDTaskNode::CopiedLbns() const {
return IsFwInCopy() ? chain_node()->input_lbns() : chain_node()->output_lbns();
}
void CopyHDTaskNode::InitWithFwNode(TaskNode* fw_node) {
TaskNode::InitWithFwNode(fw_node);
is_fw_in_copy_ = of_dynamic_cast<CopyHDTaskNode*>(fw_node)->is_fw_in_copy_;
}
void CopyHDTaskNode::FwBuildExecAndProducedRegsts(TaskGraph* gph) {
BuildExecAndProducedRegstsForCopy(gph);
void CopyHDTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
return CopyHdBuildExecAndEnrollLbn2Regsts();
}
void CopyHDTaskNode::BpBuildExecAndProducedRegsts(TaskGraph* gph) {
BuildExecAndProducedRegstsForCopy(gph);
void CopyHDTaskNode::FwInferShape4LbnInProducedRegsts(TaskGraph*) {
return CopyHdInferShape4LbnInProducedRegsts();
}
void CopyHDTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
return CopyHdBuildExecAndEnrollLbn2Regsts();
}
void CopyHDTaskNode::BpInferShape4LbnInProducedRegsts(TaskGraph*) {
return CopyHdInferShape4LbnInProducedRegsts();
}
void CopyHDTaskNode::CopyHdBuildExecAndEnrollLbn2Regsts(){
auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyLbnFrom(in_regst);
OperatorConf op_conf;
op_conf.set_name("copy_hd_" + NewUniqueId());
CopyHdOpConf* copy_hd_conf = op_conf.mutable_copy_hd_conf();
copy_hd_conf->set_type(IsH2D() ? CopyHdOpConf::H2D : CopyHdOpConf::D2H);
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = OpMgr::Singleton().ConstructOp(op_conf);
node->BindBnInOpAndRegst(node->op()->SoleIbn(), in_regst);
node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst.get());
mut_exec_gph().UpdateSourceAndSink();
EnrollProducedRegstDesc("out", std::move(out_regst));
}
void CopyHDTaskNode::CopyHdInferShape4LbnInProducedRegsts() {
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
RegstDesc* out_regst = GetRelatedRegst(SoleOutEdge());
out_regst->CopyShapeFrom(in_regst);
}
} // namespace oneflow
......@@ -23,22 +23,24 @@ class CopyHDTaskNode final : public TaskNode {
void SetFwInCopy();
void SetFwOutCopy();
const std::vector<std::string>& CopiedLbns() const;
std::string VisualStr() const override {
return TaskNode::VisualStr() + "CopyHD_" + node_id_str();
return TaskNode::VisualStr() + "CopyHD";
}
private:
void InitWithFwNode(TaskNode* fw_node) override;
void BuildExecAndProducedRegstsForCopy(TaskGraph*);
void FwBuildExecAndProducedRegsts(TaskGraph*) override;
void BpBuildExecAndProducedRegsts(TaskGraph*) override;
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CopyHDTaskNode> ();
}
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void FwInferShape4LbnInProducedRegsts(TaskGraph*) override;
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void BpInferShape4LbnInProducedRegsts(TaskGraph*) override;
void CopyHdBuildExecAndEnrollLbn2Regsts();
void CopyHdInferShape4LbnInProducedRegsts();
bool is_fw_in_copy_;
};
......
......@@ -44,17 +44,25 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std::shared_ptr<const Operator>& mut_op() { return op_; }
void BindBnInOpAndRegst(const std::string& bn_in_op, RegstDesc* regst) {
CHECK(bn_in_op2regst.emplace(bn_in_op, regst).second);
CHECK(bn_in_op2regst_.emplace(bn_in_op, regst).second);
}
RegstDesc* GetRegstFromBnInOp(const std::string& bn_in_op) {
return bn_in_op2regst.at(bn_in_op);
return bn_in_op2regst_.at(bn_in_op);
}
void BindBnInOpAndShapePtr(const std::string& bn_in_op, Shape* shape_ptr) {
CHECK(bn_in_op2shape_ptr_.emplace(bn_in_op, shape_ptr).second);
}
const HashMap<std::string, Shape*>& BnInOp2ShapePtr() const {
return bn_in_op2shape_ptr_;
}
std::string VisualStr() const { TODO(); }
private:
std::shared_ptr<const Operator> op_;
HashMap<std::string, RegstDesc*> bn_in_op2regst;
HashMap<std::string, RegstDesc*> bn_in_op2regst_;
HashMap<std::string, Shape*> bn_in_op2shape_ptr_;
};
......
......@@ -16,9 +16,15 @@ inline void TaskConnect(TaskNode* src_node,
}
void TaskGraph::BuildExecAndProducedRegsts() {
void TaskGraph::BuildExecAndEnrollLbn2Regsts() {
for (TaskNode& node : *this) {
node.BuildExecAndProducedRegsts(this);
node.BuildExecAndEnrollLbn2Regsts(this);
}
}
void TaskGraph::InferShape4LbnInProducedRegsts() {
for (TaskNode& node : *this) {
node.InferShape4LbnInProducedRegsts(this);
}
}
......
......@@ -17,18 +17,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
virtual ~TaskGraph() = default;
// Getters
const StageGraph* stage_gph() const { return stage_gph_.get(); }
const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); }
const HashMap<CompTaskNode*, CompTaskNode*>& faker2mccoy() {
return faker2mccoy_;
}
std::vector<CompTaskNode*> SortedCompTasksInChain(const ChainNode*) const;
void BuildExecAndProducedRegsts();
// Build Exec And Set Produced Registers
void BuildExecAndEnrollLbn2Regsts();
void InferShape4LbnInProducedRegsts();
typedef void (CompTaskNode::*CompTaskNodeMemFunc)(TaskGraph*);
virtual CompTaskNodeMemFunc Func4FwBuildExecAndProducedRegsts() const = 0;
std::vector<CompTaskNode*> SortedCompTasksInChain(const ChainNode*) const;
using CompTaskNodeMemFunc = void (CompTaskNode::*)(TaskGraph*);
virtual CompTaskNodeMemFunc Func4FwBuildExecAndEnrollLbn2Regsts() const = 0;
virtual CompTaskNodeMemFunc Func4FwInferShape4LbnInProducedRegsts() const = 0;
protected:
TaskGraph() = default;
......
......@@ -40,11 +40,19 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
return bp_node;
}
void TaskNode::BuildExecAndProducedRegsts(TaskGraph* gph) {
void TaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
if (IsFwNode()) {
FwBuildExecAndProducedRegsts(gph);
FwBuildExecAndEnrollLbn2Regsts(gph);
} else {
BpBuildExecAndProducedRegsts(gph);
BpBuildExecAndEnrollLbn2Regsts(gph);
}
}
void TaskNode::InferShape4LbnInProducedRegsts(TaskGraph* gph) {
if (IsFwNode()) {
FwInferShape4LbnInProducedRegsts();
} else {
BpInferShape4LbnInProducedRegsts();
}
}
......
......@@ -4,7 +4,7 @@
#include "task/task.pb.h"
#include "graph/stage_graph.h"
#include "graph/exec_graph.h"
#include "register/register_desc.h"
#include "register/register_desc_manager.h"
namespace oneflow {
......@@ -39,7 +39,8 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std::unique_ptr<TaskNode> BuildAndConnectBpNode();
//
void BuildExecAndProducedRegsts(TaskGraph*);
void BuildExecAndEnrollLbn2Regsts(TaskGraph*);
void InferShape4LbnInProducedRegsts(TaskGraph*);
RegstDesc* GetProducedRegstDesc(const std::string& regst_desc_name);
//
......@@ -67,8 +68,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
void EnrollProducedRegstDesc(const std::string& regst_desc_name,
std::unique_ptr<RegstDesc>&& regst_desc);
virtual void FwBuildExecAndProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void BpBuildExecAndProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void FwInferShape4LbnInProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); }
virtual void BpInferShape4LbnInProducedRegsts(TaskGraph*) { UNEXPECTED_RUN(); }
private:
// In task_gph level
......
......@@ -13,7 +13,7 @@ class CommNetOp final : public SysOperator {
~CommNetOp() = default;
void InitFromOpConf(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4ObAndDtbFromIb() const override { UNEXPECTED_RUN(); }
std::string GetValueFromPbOpConf(const std::string& k) const override;
std::string normal_ibn2lbn(const std::string& input_bn) const override;
......
#include "operator/copy_hd_op.h"
#include "operator/operator_manager.h"
namespace oneflow {
void CopyHdOp::InitFromOpConf(const OperatorConf& op_conf) {
CHECK(op_conf.has_copy_hd_conf());
mut_op_conf() = op_conf;
EnrollInputBn("in");
EnrollOutputBn("out");
}
std::string CopyHdOp::GetValueFromPbOpConf(const std::string& k) const {
return GetValueFromPbMessage(op_conf().copy_hd_conf(), k);
}
REGISTER_OP(OperatorConf::kCopyConf, CopyHdOp);
} // namespace oneflow
#ifndef ONEFLOW_OPERATOR_COPY_OP_H_
#define ONEFLOW_OPERATOR_COPY_OP_H_
#ifndef ONEFLOW_OPERATOR_COPY_HD_OP_H_
#define ONEFLOW_OPERATOR_COPY_HD_OP_H_
#include "operator/operator.h"
namespace oneflow {
class CopyOp final : public SysOperator {
class CopyHdOp final : public SysOperator {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyOp);
CopyOp() = default;
~CopyOp() = default;
OF_DISALLOW_COPY_AND_MOVE(CopyHdOp);
CopyHdOp() = default;
~CopyHdOp() = default;
void InitFromOpConf(const OperatorConf& op_conf) override;
void InitFromOperatorProto(const OperatorProto& operatorproto) override;
OperatorProto ToOperatorProto() override;
void InferShape4ObAndDtbFromIb() const override;
std::string GetValueFromPbOpConf(const std::string& k) const override;
void InferShape4ObAndDtbFromIb() const override { UNEXPECTED_RUN(); }
std::string normal_ibn2lbn(const std::string& input_bn) const override {
return ibn2lbn_.at(input_bn);
return RegstDesc::kAllLbn;
}
std::string obn2lbn(const std::string& output_bn) const override {
return obn2lbn_.at(output_bn);
return RegstDesc::kAllLbn;
}
private:
HashMap<std::string, std::string> ibn2lbn_;
HashMap<std::string, std::string> obn2lbn_;
};
} // namespace oneflow
#endif // ONEFLOW_OPERATOR_COPY_OP_H_
#endif // ONEFLOW_OPERATOR_COPY_HD_OP_H_
#include "operator/copy_op.h"
#include "operator/operator_manager.h"
namespace oneflow {
void CopyOp::InitFromOpConf(const OperatorConf& op_conf) {
CHECK(op_conf.has_copy_conf());
mut_op_conf() = op_conf;
for (int64_t i = 0; i < op_conf.copy_conf().copied_lbns_size(); ++i) {
std::string ibn = "in_" + std::to_string(i);
EnrollInputBn(ibn);
CHECK(ibn2lbn_.emplace(ibn, op_conf.copy_conf().copied_lbns(i)).second);
std::string obn = "out_" + std::to_string(i);
EnrollOutputBn(obn);
CHECK(obn2lbn_.emplace(obn, op_conf.copy_conf().copied_lbns(i)).second);
}
}
std::string CopyOp::GetValueFromPbOpConf(const std::string& k) const {
return GetValueFromPbMessage(op_conf().copy_conf(), k);
}
void CopyOp::InitFromOperatorProto(const OperatorProto& operatorproto) {
CHECK(operatorproto.has_copy_op());
Operator::InitFromOperatorProto(operatorproto);
ibn2lbn_ = PbMap2HashMap(operatorproto.copy_op().ibn2lbn());
obn2lbn_ = PbMap2HashMap(operatorproto.copy_op().obn2lbn());
}
OperatorProto CopyOp::ToOperatorProto() {
OperatorProto operatorproto = Operator::ToOperatorProto();
CopyOpProto copyopproto;
*(copyopproto.mutable_ibn2lbn()) = HashMap2PbMap(ibn2lbn_);
*(copyopproto.mutable_obn2lbn()) = HashMap2PbMap(obn2lbn_);
*(operatorproto.mutable_copy_op()) = copyopproto;
return operatorproto;
}
void CopyOp::InferShape4ObAndDtbFromIb() const {
CHECK_EQ(output_bns().size(), input_bns().size());
for(size_t i = 0;i < output_bns().size();++ i){
std::string obn = output_bns().at(i);
std::string ibn = input_bns().at(i);
*GetShapePtr(obn) = *GetShapePtr(ibn);
}
}
REGISTER_OP(OperatorConf::kCopyConf, CopyOp);
} // namespace oneflow
......@@ -170,13 +170,12 @@ message CommNetOpConf {
CommNetType comm_net_type = 1;
}
message CopyOpConf {
enum CopyType {
message CopyHdOpConf {
enum CopyHdType {
H2D = 0;
D2H = 1;
}
CopyType copy_type = 1;
repeated string copied_lbns = 2;
CopyHdType type = 1;
}
......@@ -186,11 +185,14 @@ message CloneOpConf {
}
message BoxConcatConf {
int32 axis = 1;
enum ConcatType {
kData = 0;
kModel = 1;
};
ConcatType type = 1;
}
message BoxSplitConf {
int32 axis = 1;
message BoxDataSplitConf {
}
message BoxCloneConf {
......@@ -202,7 +204,7 @@ message BoxingOpConf {
uint32 out_num = 3;
BoxConcatConf concat_box = 4;
oneof out_box {
BoxSplitConf split_box = 5;
BoxDataSplitConf data_split_box = 5;
BoxCloneConf clone_box = 6;
}
}
......@@ -226,7 +228,7 @@ message OperatorConf {
ReluOpConf relu_conf = 104;
SoftmaxOpConf softmax_conf = 105;
MultinomialLogisticLossOpConf multinomial_logistic_loss_conf = 106;
CopyOpConf copy_conf = 107;
CopyHdOpConf copy_hd_conf = 107;
CloneOpConf clone_conf = 108;
BoxingOpConf boxing_conf = 109;
ModelUpdateOpConf model_update_conf = 110;
......
syntax = "proto3";
package oneflow;
import "operator/op_conf.proto";
message CopyOpProto {
map<string, string> ibn2lbn = 1;
map<string, string> obn2lbn = 2;
}
import "operator/op_conf.proto";
message OperatorProto {
OperatorConf user_conf = 1;
......@@ -22,6 +18,5 @@ message OperatorProto {
repeated string model_tmp_bns = 11;
oneof specified_op_proto {
CopyOpProto copy_op = 100;
}
}
......@@ -7,14 +7,22 @@ RegstDesc::RegstDesc() {
producer_ = nullptr;
}
void RegstDesc::CopyLbn2ShapeMap(const RegstDesc* rhs) {
void RegstDesc::CopyLbnFrom(const RegstDesc* rhs) {
lbn2shape_.clear();
for (const auto& pair : rhs->lbn2shape_) {
const std::string& lbn = pair.first;
auto shape = of_make_unique<Shape> (*(pair.second));
auto shape = of_make_unique<Shape> ();
CHECK(lbn2shape_.emplace(lbn, std::move(shape)).second);
}
}
void RegstDesc::CopyShapeFrom(const RegstDesc* rhs) {
for (const auto& pair : lbn2shape_) {
const std::string& lbn = pair.first;
*(lbn2shape_.at(lbn)) = rhs->GetShape(lbn);
}
}
Shape* RegstDesc::EnrollLbn(const std::string& lbn) {
Shape* raw_ptr = new Shape;
std::unique_ptr<Shape> uptr(raw_ptr);
......
......@@ -11,11 +11,11 @@ namespace oneflow {
class TaskNode;
class RegstDesc {
class RegstDesc final {
public:
OF_DISALLOW_COPY_AND_MOVE(RegstDesc);
RegstDesc();
virtual ~RegstDesc() = default;
~RegstDesc() = default;
// regst_desc_id
uint64_t regst_desc_id() const { return regst_desc_id_; }
......@@ -25,7 +25,8 @@ class RegstDesc {
void SetProducer(const TaskNode* task_node) { producer_ = task_node; }
// Lbn and Shape
void CopyLbn2ShapeMap(const RegstDesc*);
void CopyLbnFrom(const RegstDesc*);
void CopyShapeFrom(const RegstDesc*);
Shape* EnrollLbn(const std::string& lbn);
const Shape& GetShape(const std::string& lbn);
Shape* GetMutShapePtr(const std::string& lbn);
......@@ -40,24 +41,6 @@ class RegstDesc {
};
class ContigRegstDesc final : public RegstDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(ContigRegstDesc);
ContigRegstDesc() = default;
~ContigRegstDesc() = default;
private:
};
class DisContigRegstDesc final : public RegstDesc {
public:
OF_DISALLOW_COPY_AND_MOVE(DisContigRegstDesc);
DisContigRegstDesc() = default;
~DisContigRegstDesc() = default;
};
} // namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_H_
#ifndef ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#define ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#include "register/register_desc.h"
namespace oneflow {
class RegstDescMgr final {
public:
OF_DISALLOW_COPY_AND_MOVE(RegstDescMgr);
RegstDescMgr() = default;
~RegstDescMgr() = default;
std::unique_ptr<RegstDesc> CreateRegisterDesc() {
return of_make_unique<RegstDesc> ();
}
private:
};
} // namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册