From 32022d4df0200284cb4ab2ee98cc9e380f31e62d Mon Sep 17 00:00:00 2001 From: willzhang4a58 Date: Fri, 18 Aug 2017 16:25:46 +0800 Subject: [PATCH] blob desc --- cmake/oneflow.cmake | 8 +-- oneflow/core/graph/boxing_task_node.cpp | 14 ++--- oneflow/core/graph/boxing_task_node.h | 6 +-- oneflow/core/graph/copy_task_node.cpp | 4 +- oneflow/core/graph/copy_task_node.h | 2 +- oneflow/core/graph/data_comp_task_node.cpp | 19 ++++--- oneflow/core/graph/data_comp_task_node.h | 6 +-- oneflow/core/graph/exec_graph.cpp | 6 +-- oneflow/core/graph/exec_graph.h | 2 +- .../graph/loss_accumulate_comp_task_node.cpp | 4 +- .../graph/loss_accumulate_comp_task_node.h | 2 +- .../core/graph/loss_record_comp_task_node.cpp | 3 +- .../core/graph/loss_record_comp_task_node.h | 2 +- .../model_diff_accumulate_comp_task_node.cpp | 4 +- .../model_diff_accumulate_comp_task_node.h | 2 +- .../core/graph/model_save_comp_task_node.cpp | 2 +- .../core/graph/model_save_comp_task_node.h | 2 +- .../graph/model_update_comp_task_node.cpp | 12 ++--- .../core/graph/model_update_comp_task_node.h | 2 +- oneflow/core/graph/task_graph.cpp | 7 ++- oneflow/core/graph/task_graph.h | 2 +- oneflow/core/graph/task_node.h | 2 +- oneflow/core/job/compiler.cpp | 2 +- oneflow/core/operator/boxing_op.cpp | 17 +++--- oneflow/core/operator/boxing_op.h | 4 +- oneflow/core/operator/boxing_op_test.cpp | 6 +-- oneflow/core/operator/clone_op.cpp | 8 +-- oneflow/core/operator/clone_op.h | 4 +- oneflow/core/operator/clone_op_test.cpp | 2 +- oneflow/core/operator/concat_op.cpp | 23 ++++---- oneflow/core/operator/concat_op.h | 4 +- oneflow/core/operator/concat_op_test.cpp | 2 +- oneflow/core/operator/convolution_op.cpp | 43 +++++++-------- oneflow/core/operator/convolution_op.h | 4 +- oneflow/core/operator/convolution_op_test.cpp | 4 +- oneflow/core/operator/data_loader_op.cpp | 8 +-- oneflow/core/operator/data_loader_op.h | 4 +- oneflow/core/operator/innerproduct_op.cpp | 20 ++++--- oneflow/core/operator/innerproduct_op.h | 4 +- .../core/operator/innerproduct_op_test.cpp | 4 +- oneflow/core/operator/model_update_op.h | 4 +- .../operator/momentum_model_update_op.cpp | 7 ++- .../core/operator/momentum_model_update_op.h | 4 +- .../operator/multinomial_logistic_loss_op.cpp | 8 +-- .../operator/multinomial_logistic_loss_op.h | 4 +- .../multinomial_logistic_loss_op_test.cpp | 2 +- oneflow/core/operator/operator.h | 10 ++-- oneflow/core/operator/pooling_op.cpp | 26 +++++---- oneflow/core/operator/pooling_op.h | 4 +- oneflow/core/operator/pooling_op_test.cpp | 2 +- oneflow/core/operator/record_op.h | 4 +- oneflow/core/operator/relu_op.cpp | 8 ++- oneflow/core/operator/relu_op.h | 4 +- oneflow/core/operator/relu_op_test.cpp | 2 +- .../core/operator/rmsprop_model_update_op.cpp | 7 ++- .../core/operator/rmsprop_model_update_op.h | 4 +- oneflow/core/operator/softmax_loss_op.cpp | 14 ++--- oneflow/core/operator/softmax_loss_op.h | 4 +- .../core/operator/softmax_loss_op_test.cpp | 2 +- oneflow/core/operator/softmax_op.cpp | 11 ++-- oneflow/core/operator/softmax_op.h | 4 +- oneflow/core/operator/softmax_op_test.cpp | 2 +- oneflow/core/register/blob.h | 2 +- oneflow/core/register/blob_desc.h | 30 +++++++++++ oneflow/core/register/blob_desc.proto | 8 +++ oneflow/core/register/register_desc.cpp | 53 ++++++++++--------- oneflow/core/register/register_desc.h | 21 ++++---- oneflow/core/register/register_desc.proto | 4 +- oneflow/core/register/register_manager.cpp | 13 ++--- .../core/register/runtime_register_desc.cpp | 5 +- oneflow/core/register/runtime_register_desc.h | 8 +-- 71 files changed, 296 insertions(+), 261 deletions(-) create mode 100644 oneflow/core/register/blob_desc.h create mode 100644 oneflow/core/register/blob_desc.proto diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake index 5cebccd794..2de16af20f 100644 --- a/cmake/oneflow.cmake +++ b/cmake/oneflow.cmake @@ -108,6 +108,8 @@ foreach(cc ${of_main_cc}) endforeach() # build test -cuda_add_executable(oneflow_testexe ${of_all_test_cc}) -target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs}) -add_test(NAME oneflow_test COMMAND oneflow_testexe) +if(BUILD_TESTING) + cuda_add_executable(oneflow_testexe ${of_all_test_cc}) + target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs}) + add_test(NAME oneflow_test COMMAND oneflow_testexe) +endif() diff --git a/oneflow/core/graph/boxing_task_node.cpp b/oneflow/core/graph/boxing_task_node.cpp index 2d445165f2..af700d13fe 100644 --- a/oneflow/core/graph/boxing_task_node.cpp +++ b/oneflow/core/graph/boxing_task_node.cpp @@ -173,11 +173,11 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair( } } -void BoxingTaskNode::FwInferShapeOfBlobsInProducedRegsts(TaskGraph*) { +void BoxingTaskNode::FwInferBlobDescInProducedRegsts(TaskGraph*) { exec_gph().ConstForEachNode([this](const ExecNode* exec_node) { - exec_node->op()->InferShape4FwBlobs(exec_node->GetMutShapePtr4BnInOpFunc(), - chain_node()->parallel_desc()->policy(), - 0, 0); + exec_node->op()->InferBlobDesc4FwBlobs( + exec_node->GetBlobDesc4BnInOpFunc(), + chain_node()->parallel_desc()->policy(), 0, 0); }); } @@ -232,16 +232,16 @@ void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) { mut_exec_gph().UpdateSourceAndSink(); } -void BoxingTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) { +void BoxingTaskNode::BpInferBlobDescInProducedRegsts(TaskGraph*) { for (TaskEdge* fw_in_edge : GetFwNode()->in_edges()) { auto in_regst = GetRelatedRegst(fw_in_edge); if (auto in_diff_regst = GetBpRegstFromFwRegst(in_regst)) { - in_diff_regst->CopyShapeFrom(in_regst.get()); + in_diff_regst->CopyBlobDescFrom(in_regst.get()); } } auto fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle"); auto bp_middle_regst = GetProducedRegstDesc("middle"); - bp_middle_regst->CopyShapeFrom(fw_middle_regst.get()); + bp_middle_regst->CopyBlobDescFrom(fw_middle_regst.get()); } } // namespace oneflow diff --git a/oneflow/core/graph/boxing_task_node.h b/oneflow/core/graph/boxing_task_node.h index 5101ee2443..ecb50731cb 100644 --- a/oneflow/core/graph/boxing_task_node.h +++ b/oneflow/core/graph/boxing_task_node.h @@ -41,12 +41,12 @@ class BoxingTaskNode : public TaskNode { private: OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts); - OVERRIDE_IF_FW_BP_FOR_FUNC(InferShapeOfBlobsInProducedRegsts); + OVERRIDE_IF_FW_BP_FOR_FUNC(InferBlobDescInProducedRegsts); void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*); - void FwInferShapeOfBlobsInProducedRegsts(TaskGraph*); + void FwInferBlobDescInProducedRegsts(TaskGraph*); void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*); - void BpInferShapeOfBlobsInProducedRegsts(TaskGraph*); + void BpInferBlobDescInProducedRegsts(TaskGraph*); void EnrollAllRegstAndBindRelatedEdge(); TaskType task_type() const override { return kBoxingTask; } diff --git a/oneflow/core/graph/copy_task_node.cpp b/oneflow/core/graph/copy_task_node.cpp index 0d8d3d8792..39ba616de8 100644 --- a/oneflow/core/graph/copy_task_node.cpp +++ b/oneflow/core/graph/copy_task_node.cpp @@ -25,10 +25,10 @@ void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*) { mut_exec_gph().UpdateSourceAndSink(); } -void CopyTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph*) { +void CopyTaskNode::InferBlobDescInProducedRegsts(TaskGraph*) { std::shared_ptr in_regst = GetRelatedRegst(SoleInEdge()); std::shared_ptr out_regst = GetRelatedRegst(SoleOutEdge()); - out_regst->CopyShapeFrom(in_regst.get()); + out_regst->CopyBlobDescFrom(in_regst.get()); } void CopyHDTaskNode::SetFwInCopy() { diff --git a/oneflow/core/graph/copy_task_node.h b/oneflow/core/graph/copy_task_node.h index 7c473c2c8f..ff3904a5a5 100644 --- a/oneflow/core/graph/copy_task_node.h +++ b/oneflow/core/graph/copy_task_node.h @@ -16,7 +16,7 @@ class CopyTaskNode : public TaskNode { private: void BuildExecAndEnrollLbn2Regsts(TaskGraph*) override; - void InferShapeOfBlobsInProducedRegsts(TaskGraph*) override; + void InferBlobDescInProducedRegsts(TaskGraph*) override; }; class CopyHDTaskNode final : public CopyTaskNode { diff --git a/oneflow/core/graph/data_comp_task_node.cpp b/oneflow/core/graph/data_comp_task_node.cpp index bd5485eb3f..f705ed0300 100644 --- a/oneflow/core/graph/data_comp_task_node.cpp +++ b/oneflow/core/graph/data_comp_task_node.cpp @@ -24,17 +24,16 @@ void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) { FwEnrollLbn2ModelAndTmpRegsts(); // model model_tmp data_tmp } -void DataCompTaskNode::FwInferShapeOfBlobsInProducedRegsts(TaskGraph*) { +void DataCompTaskNode::FwInferBlobDescInProducedRegsts(TaskGraph*) { exec_gph().ConstTopoForEachNode([this](const ExecNode* node) { - node->op()->InferShape4FwBlobs( - node->GetMutShapePtr4BnInOpFunc(), - chain_node()->parallel_desc()->policy(), parallel_id(), - chain_node()->parallel_desc()->parallel_num()); + node->op()->InferBlobDesc4FwBlobs( + node->GetBlobDesc4BnInOpFunc(), chain_node()->parallel_desc()->policy(), + parallel_id(), chain_node()->parallel_desc()->parallel_num()); }); if (IsLossNode()) { auto out_regst = GetRelatedRegst(SoleOutEdge()); auto in_regst = GetRelatedRegst(SoleInEdge()); - out_regst->CopyShapeFrom(in_regst.get()); + out_regst->CopyBlobDescFrom(in_regst.get()); } } @@ -173,20 +172,20 @@ void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) { BpEnrollLbn2ProducedRegst(); } -void DataCompTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) { +void DataCompTaskNode::BpInferBlobDescInProducedRegsts(TaskGraph*) { // in_diff_regst auto in_diff_regst = GetProducedRegstDesc("in_diff"); auto in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge()); - in_diff_regst->CopyShapeFrom(in_regst.get()); + in_diff_regst->CopyBlobDescFrom(in_regst.get()); // model_diff_regst if (auto md_diff_regst = GetProducedRegstDesc("model_diff")) { - md_diff_regst->CopyShapeFrom( + md_diff_regst->CopyBlobDescFrom( GetFwNode()->GetConsumedRegstDesc("model").get()); } // activation_diff_regst if (auto acti_diff_regst = GetProducedRegstDesc("activation_diff")) { auto acti_regst = GetFwNode()->GetProducedRegstDesc("activation"); - acti_diff_regst->CopyShapeFrom(acti_regst.get()); + acti_diff_regst->CopyBlobDescFrom(acti_regst.get()); } } diff --git a/oneflow/core/graph/data_comp_task_node.h b/oneflow/core/graph/data_comp_task_node.h index 2dc336bb51..eb5e35055d 100644 --- a/oneflow/core/graph/data_comp_task_node.h +++ b/oneflow/core/graph/data_comp_task_node.h @@ -41,12 +41,12 @@ class DataCompTaskNode final : public CompTaskNode { private: OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts); - OVERRIDE_IF_FW_BP_FOR_FUNC(InferShapeOfBlobsInProducedRegsts); + OVERRIDE_IF_FW_BP_FOR_FUNC(InferBlobDescInProducedRegsts); using Lbn2NodeBnMap = HashMap>; void FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph); - void FwInferShapeOfBlobsInProducedRegsts(TaskGraph* gph); + void FwInferBlobDescInProducedRegsts(TaskGraph* gph); void FwBuildFromUserOps(Lbn2NodeBnMap* lbn2producer, Lbn2NodeBnMap* extern_in_lbn2consumer); void FwSetExecNodeFromInRegst(const Lbn2NodeBnMap& extern_in_lbn2consumer); @@ -56,7 +56,7 @@ class DataCompTaskNode final : public CompTaskNode { void FwEnrollLbn2ActivationRegst(); void FwEnrollLbn2ModelAndTmpRegsts(); void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*); - void BpInferShapeOfBlobsInProducedRegsts(TaskGraph*); + void BpInferBlobDescInProducedRegsts(TaskGraph*); void BpBuildExecGraph(); void BpEnrollLbn2ProducedRegst(); void BpEnrollLbn2ActivationDiffRegst(); diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index 7b1affb9df..b49859119d 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -4,14 +4,14 @@ namespace oneflow { void ExecEdge::set_lbn(const std::string& lbn) { lbn_ = lbn; } -std::function ExecNode::GetMutShapePtr4BnInOpFunc() +std::function ExecNode::GetBlobDesc4BnInOpFunc() const { - return [this](const std::string& bn_in_op) -> Shape* { + return [this](const std::string& bn_in_op) -> BlobDesc* { auto it = this->bn_in_op2regst_.find(bn_in_op); if (it == this->bn_in_op2regst_.end()) { return nullptr; } std::shared_ptr regst = it->second.lock(); const std::string& lbn = this->op()->Lbn4BnInOp(bn_in_op); - return regst->GetMutShapePtr(lbn); + return regst->GetMutBlobDesc(lbn); }; } diff --git a/oneflow/core/graph/exec_graph.h b/oneflow/core/graph/exec_graph.h index 28ec35c062..abaaf79920 100644 --- a/oneflow/core/graph/exec_graph.h +++ b/oneflow/core/graph/exec_graph.h @@ -55,7 +55,7 @@ class ExecNode final : public Node { return bn_in_op2regst_; } - std::function GetMutShapePtr4BnInOpFunc() const; + std::function GetBlobDesc4BnInOpFunc() const; std::string VisualStr() const { return op_->op_name(); } diff --git a/oneflow/core/graph/loss_accumulate_comp_task_node.cpp b/oneflow/core/graph/loss_accumulate_comp_task_node.cpp index 3111000cd7..d4b6a4f09a 100644 --- a/oneflow/core/graph/loss_accumulate_comp_task_node.cpp +++ b/oneflow/core/graph/loss_accumulate_comp_task_node.cpp @@ -22,11 +22,11 @@ void LossAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) { mut_exec_gph().UpdateSourceAndSink(); } -void LossAccCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) { +void LossAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) { if (!chain_node()->op_vec().empty()) { auto loss_regst = GetConsumedRegstDesc("loss"); auto loss_acc_regst = GetProducedRegstDesc("loss_acc"); - loss_acc_regst->CopyShapeFrom(loss_regst.get()); + loss_acc_regst->CopyBlobDescFrom(loss_regst.get()); } } diff --git a/oneflow/core/graph/loss_accumulate_comp_task_node.h b/oneflow/core/graph/loss_accumulate_comp_task_node.h index fee4da89b0..edf8151382 100644 --- a/oneflow/core/graph/loss_accumulate_comp_task_node.h +++ b/oneflow/core/graph/loss_accumulate_comp_task_node.h @@ -13,7 +13,7 @@ class LossAccCompTaskNode final : public CompTaskNode { private: void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override; - void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override; + void InferBlobDescInProducedRegsts(TaskGraph* gph) override; TaskType task_type() const override { return kLossAccCompTask; } std::unique_ptr CreateSameTypeNode() const override { return of_make_unique(); diff --git a/oneflow/core/graph/loss_record_comp_task_node.cpp b/oneflow/core/graph/loss_record_comp_task_node.cpp index 0fd6707303..f84bff7112 100644 --- a/oneflow/core/graph/loss_record_comp_task_node.cpp +++ b/oneflow/core/graph/loss_record_comp_task_node.cpp @@ -20,7 +20,6 @@ void LossRecordCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) { mut_exec_gph().UpdateSourceAndSink(); } -void LossRecordCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) { -} +void LossRecordCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {} } // namespace oneflow diff --git a/oneflow/core/graph/loss_record_comp_task_node.h b/oneflow/core/graph/loss_record_comp_task_node.h index fea270d6c9..7387cf97a7 100644 --- a/oneflow/core/graph/loss_record_comp_task_node.h +++ b/oneflow/core/graph/loss_record_comp_task_node.h @@ -13,7 +13,7 @@ class LossRecordCompTaskNode final : public CompTaskNode { private: void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override; - void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override; + void InferBlobDescInProducedRegsts(TaskGraph* gph) override; bool IsMeaningLess() const override { return !GetConsumedRegstDesc("loss_acc"); } diff --git a/oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp b/oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp index 53bd780c41..3a00d46360 100644 --- a/oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp +++ b/oneflow/core/graph/model_diff_accumulate_comp_task_node.cpp @@ -35,13 +35,13 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) { mut_exec_gph().UpdateSourceAndSink(); } -void MdDiffAccCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) { +void MdDiffAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) { CHECK(IsFwNode()); if (!chain_node()->op_vec().empty()) { std::shared_ptr in_regst = GetConsumedRegstDesc("model_diff"); std::shared_ptr out_regst = GetProducedRegstDesc("model_diff_acc"); - out_regst->CopyShapeFrom(in_regst.get()); + out_regst->CopyBlobDescFrom(in_regst.get()); } } diff --git a/oneflow/core/graph/model_diff_accumulate_comp_task_node.h b/oneflow/core/graph/model_diff_accumulate_comp_task_node.h index c5af6371c5..d8c6ef34a4 100644 --- a/oneflow/core/graph/model_diff_accumulate_comp_task_node.h +++ b/oneflow/core/graph/model_diff_accumulate_comp_task_node.h @@ -19,7 +19,7 @@ class MdDiffAccCompTaskNode final : public CompTaskNode { private: void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override; - void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override; + void InferBlobDescInProducedRegsts(TaskGraph* gph) override; TaskType task_type() const override { return kMdDiffAccCompTask; } std::unique_ptr CreateSameTypeNode() const override { return of_make_unique(); diff --git a/oneflow/core/graph/model_save_comp_task_node.cpp b/oneflow/core/graph/model_save_comp_task_node.cpp index 111583c8db..293cafee43 100644 --- a/oneflow/core/graph/model_save_comp_task_node.cpp +++ b/oneflow/core/graph/model_save_comp_task_node.cpp @@ -32,7 +32,7 @@ void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) { } } -void MdSaveCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) { +void MdSaveCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) { CHECK(IsFwNode()); } diff --git a/oneflow/core/graph/model_save_comp_task_node.h b/oneflow/core/graph/model_save_comp_task_node.h index 1031d6b374..d6bfd624de 100644 --- a/oneflow/core/graph/model_save_comp_task_node.h +++ b/oneflow/core/graph/model_save_comp_task_node.h @@ -22,7 +22,7 @@ class MdSaveCompTaskNode final : public CompTaskNode { private: void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override; - void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override; + void InferBlobDescInProducedRegsts(TaskGraph* gph) override; bool IsMeaningLess() const override { return !GetConsumedRegstDesc("model"); } TaskType task_type() const override { return kMdSaveCompTask; } diff --git a/oneflow/core/graph/model_update_comp_task_node.cpp b/oneflow/core/graph/model_update_comp_task_node.cpp index 9dbf88915f..0220874253 100644 --- a/oneflow/core/graph/model_update_comp_task_node.cpp +++ b/oneflow/core/graph/model_update_comp_task_node.cpp @@ -34,17 +34,17 @@ void MdUpdtCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) { mut_exec_gph().UpdateSourceAndSink(); } -void MdUpdtCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) { +void MdUpdtCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) { CHECK(IsFwNode()); ExecNode* exec_node = exec_gph().SoleNode(); auto model_diffs_regst = GetConsumedRegstDesc("model_diffs"); - Shape packed_model_diffs_shape({model_diffs_regst->CompElemCntOfAllBlob()}); - exec_node->op()->InferShape4FwBlobs( - [&](const std::string& bn_in_op) -> Shape* { + BlobDesc packed_blob_desc = model_diffs_regst->CompPackedBlobDesc(); + exec_node->op()->InferBlobDesc4FwBlobs( + [&](const std::string& bn_in_op) -> BlobDesc* { if (bn_in_op == "model_diffs") { - return &packed_model_diffs_shape; + return &packed_blob_desc; } else { - return exec_node->GetMutShapePtr4BnInOpFunc()(bn_in_op); + return exec_node->GetBlobDesc4BnInOpFunc()(bn_in_op); } }, kDataParallel, 0, 0); diff --git a/oneflow/core/graph/model_update_comp_task_node.h b/oneflow/core/graph/model_update_comp_task_node.h index 9f68eec5b1..80667d4577 100644 --- a/oneflow/core/graph/model_update_comp_task_node.h +++ b/oneflow/core/graph/model_update_comp_task_node.h @@ -35,7 +35,7 @@ class MdUpdtCompTaskNode final : public CompTaskNode { private: void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override; - void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override; + void InferBlobDescInProducedRegsts(TaskGraph* gph) override; TaskType task_type() const override { return kMdUpdtCompTask; } std::unique_ptr CreateSameTypeNode() const override { return of_make_unique(); diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 00fe01ef5d..b2da096f75 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -33,10 +33,9 @@ void TaskGraph::BuildExecAndEnrollLbn2Regsts() { [this](TaskNode* node) { node->BuildExecAndEnrollLbn2Regsts(this); }); } -void TaskGraph::InferShapeOfBlobsInProducedRegsts() { - TopoForEachNode([this](TaskNode* node) { - node->InferShapeOfBlobsInProducedRegsts(this); - }); +void TaskGraph::InferBlobDescInProducedRegsts() { + TopoForEachNode( + [this](TaskNode* node) { node->InferBlobDescInProducedRegsts(this); }); } std::vector TaskGraph::CompTasksInChain(const ChainNode* chain) { diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index ac89465f59..241c18b575 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -22,7 +22,7 @@ class TaskGraph : public Graph { const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); } std::vector CompTasksInChain(const ChainNode*); - void InferShapeOfBlobsInProducedRegsts(); + void InferBlobDescInProducedRegsts(); const std::string& name() const { return name_; } diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index e6209ca452..9a428b48b1 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -42,7 +42,7 @@ class TaskNode : public Node { // virtual void BuildExecAndEnrollLbn2Regsts(TaskGraph*) = 0; - virtual void InferShapeOfBlobsInProducedRegsts(TaskGraph*) = 0; + virtual void InferBlobDescInProducedRegsts(TaskGraph*) = 0; #define OVERRIDE_IF_FW_BP_FOR_FUNC(func_name) \ void func_name(TaskGraph* gph) override { \ diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 9133b781a0..50729e2c27 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -168,7 +168,7 @@ void Compiler::BuildLossGraph( void Compiler::InferShape4Regsts() { for (auto& task_gph : ordered_task_gphs_) { LOG(INFO) << "InferShape for " << task_gph->name(); - task_gph->InferShapeOfBlobsInProducedRegsts(); + task_gph->InferBlobDescInProducedRegsts(); } } diff --git a/oneflow/core/operator/boxing_op.cpp b/oneflow/core/operator/boxing_op.cpp index e4b296dd26..b0e31864d8 100644 --- a/oneflow/core/operator/boxing_op.cpp +++ b/oneflow/core/operator/boxing_op.cpp @@ -31,13 +31,13 @@ std::string BoxingOp::obn2lbn(const std::string& output_bn) const { return GetStringFromSpecialConf("lbn"); } -void BoxingOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void BoxingOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { auto boxing_conf = op_conf().boxing_conf(); auto in_box_case = boxing_conf.in_box_case(); std::vector data_tmp_blob_shape_vec = - GetShapePtr4BnInOp(input_bns().at(0))->dim_vec(); + GetBlobDesc4BnInOp(input_bns().at(0))->shape().dim_vec(); // if it is a concat-box, accumulate the dimensions on concat-axis. // otherwise only check all boxes are in the same shape. @@ -47,7 +47,8 @@ void BoxingOp::InferShape4FwBlobs( CHECK(concat_axis == 0 || concat_axis == 1); } for (size_t ib_idx = 1; ib_idx < input_bns().size(); ++ib_idx) { - auto ib_shape_vec = GetShapePtr4BnInOp(input_bns().at(ib_idx))->dim_vec(); + auto ib_shape_vec = + GetBlobDesc4BnInOp(input_bns().at(ib_idx))->shape().dim_vec(); for (size_t i = 0; i < ib_shape_vec.size(); ++i) { if (in_box_case == BoxingOpConf::kConcatBox && i == concat_axis) { data_tmp_blob_shape_vec[i] += ib_shape_vec[i]; @@ -61,7 +62,8 @@ void BoxingOp::InferShape4FwBlobs( // it is stored back if and only if this is a concat-clone box if (in_box_case == BoxingOpConf::kConcatBox && out_box_case == BoxingOpConf::kCloneBox) { - *GetShapePtr4BnInOp(SoleDtbn()) = Shape(data_tmp_blob_shape_vec); + GetBlobDesc4BnInOp(SoleDtbn())->mut_shape() = + Shape(data_tmp_blob_shape_vec); } CHECK_NE(out_box_case, BoxingOpConf::OUT_BOX_NOT_SET); if (out_box_case == BoxingOpConf::kDataSplitBox) { @@ -70,11 +72,12 @@ void BoxingOp::InferShape4FwBlobs( auto output_shape_vec = data_tmp_blob_shape_vec; for (size_t i = 0; i < out_num; ++i) { output_shape_vec[0] = splitter.At(i).size(); - *GetShapePtr4BnInOp(output_bns()[i]) = Shape(output_shape_vec); + GetBlobDesc4BnInOp(output_bns()[i])->mut_shape() = + Shape(output_shape_vec); } } else if (out_box_case == BoxingOpConf::kCloneBox) { for (auto obn : output_bns()) { - *GetShapePtr4BnInOp(obn) = Shape(data_tmp_blob_shape_vec); + GetBlobDesc4BnInOp(obn)->mut_shape() = Shape(data_tmp_blob_shape_vec); } } else { UNEXPECTED_RUN(); diff --git a/oneflow/core/operator/boxing_op.h b/oneflow/core/operator/boxing_op.h index 9b39639984..1ffb1cb707 100644 --- a/oneflow/core/operator/boxing_op.h +++ b/oneflow/core/operator/boxing_op.h @@ -14,8 +14,8 @@ class BoxingOp final : public SysOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/boxing_op_test.cpp b/oneflow/core/operator/boxing_op_test.cpp index a3ade1b7b2..1583b88b69 100644 --- a/oneflow/core/operator/boxing_op_test.cpp +++ b/oneflow/core/operator/boxing_op_test.cpp @@ -36,7 +36,7 @@ TEST(BoxingOp, box_4_10x5x6x6) { }; // do infer shape - boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1); + boxing_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 0, 1); // test results // output_shape should be: @@ -58,7 +58,7 @@ TEST(BoxingOp, box_4_10x5x6x6) { boxing_op = ConstructOp(op_conf); // do infer shape - boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1); + boxing_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 0, 1); // test results // output shape should be the same as input @@ -75,7 +75,7 @@ TEST(BoxingOp, box_4_10x5x6x6) { boxing_op = ConstructOp(op_conf); // do infer shape - boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1); + boxing_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 0, 1); // data_tmp_shape is {10, 17, 6, 6}, and the 17 = 4 + 4 + 4 + 5 Shape* data_tmp_shape_ptr = bn2shape_ptr.at(boxing_op->SoleDtbn()); diff --git a/oneflow/core/operator/clone_op.cpp b/oneflow/core/operator/clone_op.cpp index f49f6abf3d..cd3ba125dc 100644 --- a/oneflow/core/operator/clone_op.cpp +++ b/oneflow/core/operator/clone_op.cpp @@ -16,12 +16,12 @@ const PbMessage& CloneOp::GetSpecialConf() const { return op_conf().clone_conf(); } -void CloneOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void CloneOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn()); + const BlobDesc* input_blob_desc = GetBlobDesc4BnInOp(SoleIbn()); for (std::string obn : output_bns()) { - *GetShapePtr4BnInOp(obn) = *input_shape_ptr; + *GetBlobDesc4BnInOp(obn) = *input_blob_desc; } } diff --git a/oneflow/core/operator/clone_op.h b/oneflow/core/operator/clone_op.h index 9c70a598b5..1c7c1f5b8a 100644 --- a/oneflow/core/operator/clone_op.h +++ b/oneflow/core/operator/clone_op.h @@ -15,8 +15,8 @@ class CloneOp final : public SysOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/clone_op_test.cpp b/oneflow/core/operator/clone_op_test.cpp index 457bcf5fa0..e534bbc811 100644 --- a/oneflow/core/operator/clone_op_test.cpp +++ b/oneflow/core/operator/clone_op_test.cpp @@ -18,7 +18,7 @@ TEST(CloneOp, clone_4x3_3_times) { return bn2shape_ptr.at(bn); }; - clone_op->InferShape4FwBlobs(fp, kDataParallel, 3, 10); + clone_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 3, 10); Shape* input_shape_ptr = bn2shape_ptr.at(clone_op->SoleIbn()); for (std::string obn : clone_op->output_bns()) { diff --git a/oneflow/core/operator/concat_op.cpp b/oneflow/core/operator/concat_op.cpp index 5b790c8684..a8ce7ceba8 100644 --- a/oneflow/core/operator/concat_op.cpp +++ b/oneflow/core/operator/concat_op.cpp @@ -18,23 +18,26 @@ const PbMessage& ConcatOp::GetSpecialConf() const { return op_conf().concat_conf(); } -void ConcatOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void ConcatOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - std::vector vec = GetShapePtr4BnInOp(input_bns().at(0))->dim_vec(); + std::vector vec = + GetBlobDesc4BnInOp(input_bns().at(0))->shape().dim_vec(); for (size_t ibn_idx = 1; ibn_idx < input_bns().size(); ++ibn_idx) { - Shape* ib_shape = GetShapePtr4BnInOp(input_bns().at(ibn_idx)); + const Shape& ib_shape = + GetBlobDesc4BnInOp(input_bns().at(ibn_idx))->shape(); int32_t concat_axis = op_conf().concat_conf().axis(); - for (int64_t j = 0; j < ib_shape->NumAxes(); ++j) { - if (j == concat_axis || j == concat_axis + ib_shape->NumAxes()) { - vec[j] += ib_shape->At(j); + for (int64_t j = 0; j < ib_shape.NumAxes(); ++j) { + if (j == concat_axis || j == concat_axis + ib_shape.NumAxes()) { + vec[j] += ib_shape.At(j); } else { - CHECK_EQ(vec[j], ib_shape->At(j)); + CHECK_EQ(vec[j], ib_shape.At(j)); } } } - CHECK_EQ(vec.size(), GetShapePtr4BnInOp(input_bns().at(0))->NumAxes()); - *GetShapePtr4BnInOp(SoleObn()) = Shape(vec); + CHECK_EQ(vec.size(), + GetBlobDesc4BnInOp(input_bns().at(0))->shape().NumAxes()); + GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape(vec); } REGISTER_OP(OperatorConf::kConcatConf, ConcatOp); diff --git a/oneflow/core/operator/concat_op.h b/oneflow/core/operator/concat_op.h index f8f2b8b6f7..2f848c429f 100644 --- a/oneflow/core/operator/concat_op.h +++ b/oneflow/core/operator/concat_op.h @@ -15,8 +15,8 @@ class ConcatOp final : public UserOperator { const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/concat_op_test.cpp b/oneflow/core/operator/concat_op_test.cpp index 4f74601f47..35fd2f5ea9 100644 --- a/oneflow/core/operator/concat_op_test.cpp +++ b/oneflow/core/operator/concat_op_test.cpp @@ -21,7 +21,7 @@ TEST(ConcatOp, concat_two_3x3) { return bn2shape_ptr.at(bn); }; // infershape - concat_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); + concat_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1); // test Shape* output_shape_ptr = fp(concat_op->SoleObn()); ASSERT_EQ(*output_shape_ptr, Shape({3, 6})); diff --git a/oneflow/core/operator/convolution_op.cpp b/oneflow/core/operator/convolution_op.cpp index cf4fb2bb6c..3b60e4ac9d 100644 --- a/oneflow/core/operator/convolution_op.cpp +++ b/oneflow/core/operator/convolution_op.cpp @@ -22,15 +22,13 @@ const PbMessage& ConvolutionOp::GetSpecialConf() const { return op_conf().convolution_conf(); } -void ConvolutionOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void ConvolutionOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn()); - Shape* output_shape_ptr = GetShapePtr4BnInOp(SoleObn()); - Shape* colbuf_shape_ptr = GetShapePtr4BnInOp("col_buf"); + const Shape& input_shape = GetBlobDesc4BnInOp(SoleIbn())->shape(); auto conv_conf = op_conf().convolution_conf(); - int64_t batch_size = input_shape_ptr->At(0); - int64_t c_i = input_shape_ptr->At(1); + int64_t batch_size = input_shape.At(0); + int64_t c_i = input_shape.At(1); int32_t out_num = GetInt32FromSpecialConf("out_num"); if (policy == kModelParallel) { @@ -43,32 +41,29 @@ void ConvolutionOp::InferShape4FwBlobs( int64_t output_size = 1; std::vector output_shape_vec = {batch_size, c_o}; - int64_t h_len = (input_shape_ptr->At(2) + 2 * conv_conf.pad_h() - - conv_conf.kernel_size_h()) - / conv_conf.stride_h() - + 1; + int64_t h_len = + (input_shape.At(2) + 2 * conv_conf.pad_h() - conv_conf.kernel_size_h()) + / conv_conf.stride_h() + + 1; output_shape_vec.push_back(h_len); - int64_t w_len = (input_shape_ptr->At(3) + 2 * conv_conf.pad_w() - - conv_conf.kernel_size_w()) - / conv_conf.stride_w() - + 1; + int64_t w_len = + (input_shape.At(3) + 2 * conv_conf.pad_w() - conv_conf.kernel_size_w()) + / conv_conf.stride_w() + + 1; output_shape_vec.push_back(w_len); kernel_size *= conv_conf.kernel_size_h(); kernel_size *= conv_conf.kernel_size_w(); output_size *= h_len; output_size *= w_len; - *output_shape_ptr = Shape(output_shape_vec); - CHECK_EQ(output_shape_ptr->NumAxes(), input_shape_ptr->NumAxes()); - *colbuf_shape_ptr = Shape({batch_size, output_size, c_i * kernel_size}); - Shape* weight = GetShapePtr4BnInOp("weight"); - *weight = Shape({c_o, c_i * kernel_size}); + GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape(output_shape_vec); + GetBlobDesc4BnInOp("col_buf")->mut_shape() = + Shape({batch_size, output_size, c_i * kernel_size}); + GetBlobDesc4BnInOp("weight")->mut_shape() = Shape({c_o, c_i * kernel_size}); if (GetBoolFromSpecialConf("has_bias_term")) { - Shape* bias = GetShapePtr4BnInOp("bias"); - Shape* biasmult_shape_ptr = GetShapePtr4BnInOp("bias_multiplier"); - *bias = Shape({c_o}); - *biasmult_shape_ptr = Shape({output_size}); + GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({c_o}); + GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({output_size}); } } diff --git a/oneflow/core/operator/convolution_op.h b/oneflow/core/operator/convolution_op.h index 26b1ed1732..bb37bc7730 100644 --- a/oneflow/core/operator/convolution_op.h +++ b/oneflow/core/operator/convolution_op.h @@ -13,8 +13,8 @@ class ConvolutionOp final : public UserOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; void FixParallelDesc(ParallelDesc* pr_desc) const override { diff --git a/oneflow/core/operator/convolution_op_test.cpp b/oneflow/core/operator/convolution_op_test.cpp index 31874fbf93..3338a05f47 100644 --- a/oneflow/core/operator/convolution_op_test.cpp +++ b/oneflow/core/operator/convolution_op_test.cpp @@ -37,7 +37,7 @@ void TestDataParallelConvolutionOp() { }; // infershape - convolution_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); + convolution_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1); // test Shape* output_shape_ptr = fp(convolution_op->SoleObn()); @@ -71,7 +71,7 @@ void TestModelParallelConvolutionOp() { }; // infershape - convolution_op->InferShape4FwBlobs(fp, kModelParallel, 3, 8); + convolution_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 3, 8); // test Shape* output_shape_ptr = fp(convolution_op->SoleObn()); diff --git a/oneflow/core/operator/data_loader_op.cpp b/oneflow/core/operator/data_loader_op.cpp index 9185c5c15b..89bb75950b 100644 --- a/oneflow/core/operator/data_loader_op.cpp +++ b/oneflow/core/operator/data_loader_op.cpp @@ -15,8 +15,8 @@ const PbMessage& DataLoaderOp::GetSpecialConf() const { return op_conf().data_loader_conf(); } -void DataLoaderOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void DataLoaderOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { // useful vars int32_t piece_size = JobDesc::Singleton()->piece_size(); @@ -27,9 +27,9 @@ void DataLoaderOp::InferShape4FwBlobs( feature_shape.insert(feature_shape.end(), feature_shape_of_one_ins.dim_vec().begin(), feature_shape_of_one_ins.dim_vec().end()); - *GetShapePtr4BnInOp("feature") = Shape(feature_shape); + GetBlobDesc4BnInOp("feature")->mut_shape() = Shape(feature_shape); // label shape - *GetShapePtr4BnInOp("label") = Shape({piece_size}); + GetBlobDesc4BnInOp("label")->mut_shape() = Shape({piece_size}); } REGISTER_OP(OperatorConf::kDataLoaderConf, DataLoaderOp); diff --git a/oneflow/core/operator/data_loader_op.h b/oneflow/core/operator/data_loader_op.h index 9aa45a0105..f51a982436 100644 --- a/oneflow/core/operator/data_loader_op.h +++ b/oneflow/core/operator/data_loader_op.h @@ -14,8 +14,8 @@ class DataLoaderOp final : public SysOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/innerproduct_op.cpp b/oneflow/core/operator/innerproduct_op.cpp index f141501d57..575db254b7 100644 --- a/oneflow/core/operator/innerproduct_op.cpp +++ b/oneflow/core/operator/innerproduct_op.cpp @@ -21,10 +21,10 @@ const PbMessage& InnerProductOp::GetSpecialConf() const { return op_conf().innerproduct_conf(); } -void InnerProductOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void InnerProductOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* in_shape_ptr = GetShapePtr4BnInOp(SoleIbn()); + const Shape& in_shape = GetBlobDesc4BnInOp(SoleIbn())->shape(); int32_t out_num = GetInt32FromSpecialConf("out_num"); if (policy == kModelParallel) { BalancedSplitter splitter(out_num, parallel_num); @@ -32,22 +32,20 @@ void InnerProductOp::InferShape4FwBlobs( } // output bn - Shape* out_shape_ptr = GetShapePtr4BnInOp(SoleObn()); - *out_shape_ptr = Shape({in_shape_ptr->At(0), out_num}); + GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape({in_shape.At(0), out_num}); // model bn - Shape* weight_shape_ptr = GetShapePtr4BnInOp("weight"); - *weight_shape_ptr = Shape({out_num, in_shape_ptr->Count(1)}); + GetBlobDesc4BnInOp("weight")->mut_shape() = + Shape({out_num, in_shape.Count(1)}); if (GetBoolFromSpecialConf("has_bias_term")) { // model bn - Shape* bias_shape_ptr = GetShapePtr4BnInOp("bias"); - *bias_shape_ptr = Shape({1, out_num}); + GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({1, out_num}); // model tmp bn CHECK_EQ(model_tmp_bns().size(), 1); - Shape* bias_multiplier_shape_ptr = GetShapePtr4BnInOp("bias_multiplier"); - *bias_multiplier_shape_ptr = Shape({in_shape_ptr->At(0), 1}); + GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = + Shape({in_shape.At(0), 1}); } } diff --git a/oneflow/core/operator/innerproduct_op.h b/oneflow/core/operator/innerproduct_op.h index 6a0c2a2f52..1076b5e66e 100644 --- a/oneflow/core/operator/innerproduct_op.h +++ b/oneflow/core/operator/innerproduct_op.h @@ -13,8 +13,8 @@ class InnerProductOp final : public UserOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; void FixParallelDesc(ParallelDesc* pr_desc) const override { diff --git a/oneflow/core/operator/innerproduct_op_test.cpp b/oneflow/core/operator/innerproduct_op_test.cpp index 139e899601..606f6993dd 100644 --- a/oneflow/core/operator/innerproduct_op_test.cpp +++ b/oneflow/core/operator/innerproduct_op_test.cpp @@ -29,7 +29,7 @@ void TestModelParallelInnerProductOp(bool has_bias_term) { return bn2shape_ptr.at(bn); }; - ip_op->InferShape4FwBlobs(fp, kModelParallel, 3, 10); + ip_op->InferBlobDesc4FwBlobs(fp, kModelParallel, 3, 10); BalancedSplitter splitter(40, 10); int out_num = splitter.At(3).size(); @@ -70,7 +70,7 @@ void TestDataParallelInnerProductOp(bool has_bias_term) { return bn2shape_ptr.at(bn); }; - ip_op->InferShape4FwBlobs(fp, kDataParallel, 3, 10); + ip_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 3, 10); Shape* out_shape_ptr = bn2shape_ptr.at(ip_op->SoleObn()); CHECK_EQ(*out_shape_ptr, Shape({1000, 40})); diff --git a/oneflow/core/operator/model_update_op.h b/oneflow/core/operator/model_update_op.h index a5495d75fc..bf78c0d070 100644 --- a/oneflow/core/operator/model_update_op.h +++ b/oneflow/core/operator/model_update_op.h @@ -10,8 +10,8 @@ class ModelUpdtOp : public SysOperator { OF_DISALLOW_COPY_AND_MOVE(ModelUpdtOp); virtual ~ModelUpdtOp() = default; - virtual void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + virtual void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override { // do nothing diff --git a/oneflow/core/operator/momentum_model_update_op.cpp b/oneflow/core/operator/momentum_model_update_op.cpp index b589c33f3e..778fe991fe 100644 --- a/oneflow/core/operator/momentum_model_update_op.cpp +++ b/oneflow/core/operator/momentum_model_update_op.cpp @@ -15,11 +15,10 @@ const PbMessage& MomentumModelUpdateOp::GetSpecialConf() const { return op_conf().momentum_mdupdt_conf(); } -void MomentumModelUpdateOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void MomentumModelUpdateOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* input_shape_ptr = GetShapePtr4BnInOp("model_diffs"); - *GetShapePtr4BnInOp("momentum") = *input_shape_ptr; + TODO(); } REGISTER_OP(OperatorConf::kMomentumMdupdtConf, MomentumModelUpdateOp); diff --git a/oneflow/core/operator/momentum_model_update_op.h b/oneflow/core/operator/momentum_model_update_op.h index bfc8da32a1..ba119b20b7 100644 --- a/oneflow/core/operator/momentum_model_update_op.h +++ b/oneflow/core/operator/momentum_model_update_op.h @@ -13,8 +13,8 @@ class MomentumModelUpdateOp final : public ModelUpdtOp { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/multinomial_logistic_loss_op.cpp b/oneflow/core/operator/multinomial_logistic_loss_op.cpp index 06f1e04d6a..0ba8c43c63 100644 --- a/oneflow/core/operator/multinomial_logistic_loss_op.cpp +++ b/oneflow/core/operator/multinomial_logistic_loss_op.cpp @@ -16,11 +16,11 @@ const PbMessage& MultinomialLogisticLossOp::GetSpecialConf() const { return op_conf().multinomial_logistic_loss_conf(); } -void MultinomialLogisticLossOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void MultinomialLogisticLossOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - *GetShapePtr4BnInOp("loss") = Shape({1}); - *GetShapePtr4BnInOp("loss_buffer") = Shape({1}); + GetBlobDesc4BnInOp("loss")->mut_shape() = Shape({1}); + GetBlobDesc4BnInOp("loss_buffer")->mut_shape() = Shape({1}); } REGISTER_OP(OperatorConf::kMultinomialLogisticLossConf, diff --git a/oneflow/core/operator/multinomial_logistic_loss_op.h b/oneflow/core/operator/multinomial_logistic_loss_op.h index 6310da73df..30c5ce8b9b 100644 --- a/oneflow/core/operator/multinomial_logistic_loss_op.h +++ b/oneflow/core/operator/multinomial_logistic_loss_op.h @@ -17,8 +17,8 @@ class MultinomialLogisticLossOp final : public UserOperator { const PbMessage& GetSpecialConf() const override; bool IsLossOp() const override { return true; } - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/multinomial_logistic_loss_op_test.cpp b/oneflow/core/operator/multinomial_logistic_loss_op_test.cpp index a8fa37efeb..b7a6f678c6 100644 --- a/oneflow/core/operator/multinomial_logistic_loss_op_test.cpp +++ b/oneflow/core/operator/multinomial_logistic_loss_op_test.cpp @@ -21,7 +21,7 @@ TEST(MultinomialLogisticLossOp, test_loss_op) { return bn2shape_ptr.at(bn); }; - loss_op->InferShape4FwBlobs(fp, kDataParallel, 2, 10); + loss_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 2, 10); Shape* loss_shape_ptr = bn2shape_ptr.at(loss_op->SoleObn()); Shape* loss_buffer_shape_ptr = bn2shape_ptr.at(loss_op->SoleDtbn()); diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 2fac09b228..e96cb92fa0 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -2,13 +2,13 @@ #define ONEFLOW_CORE_OPERATOR_OPERATOR_H_ #include "oneflow/core/common/protobuf.h" -#include "oneflow/core/common/shape.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/keyword.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/operator.pb.h" +#include "oneflow/core/register/blob_desc.h" namespace oneflow { @@ -79,8 +79,8 @@ class Operator { // Read: shape of input_blobs // Write: shape of output_blobs, model_blobs, data_tmp_blobs, model_tmp_blobs - virtual void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + virtual void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const = 0; @@ -144,8 +144,8 @@ class SysOperator : public Operator { SysOperator() = default; virtual ~SysOperator() = default; - virtual void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + virtual void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override { UNEXPECTED_RUN(); diff --git a/oneflow/core/operator/pooling_op.cpp b/oneflow/core/operator/pooling_op.cpp index a5a2b420dc..0325f814d9 100644 --- a/oneflow/core/operator/pooling_op.cpp +++ b/oneflow/core/operator/pooling_op.cpp @@ -15,32 +15,30 @@ const PbMessage& PoolingOp::GetSpecialConf() const { return op_conf().pooling_conf(); } -void PoolingOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void PoolingOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn()); - CHECK_EQ(input_shape_ptr->NumAxes(), 4); - Shape* output_shape_ptr = GetShapePtr4BnInOp(SoleObn()); + const Shape& input_shape = GetBlobDesc4BnInOp(SoleIbn())->shape(); + CHECK_EQ(input_shape.NumAxes(), 4); + BlobDesc* output_blob_desc = GetBlobDesc4BnInOp(SoleObn()); const PoolingOpConf& pooling_conf = op_conf().pooling_conf(); - std::vector output_shape_dim_vec = {input_shape_ptr->At(0), - input_shape_ptr->At(1)}; + std::vector output_shape_dim_vec = {input_shape.At(0), + input_shape.At(1)}; - output_shape_dim_vec.push_back((input_shape_ptr->At(2) - + 2 * pooling_conf.pad_h() + output_shape_dim_vec.push_back((input_shape.At(2) + 2 * pooling_conf.pad_h() - pooling_conf.kernel_size_h()) / pooling_conf.stride_h() + 1); - output_shape_dim_vec.push_back((input_shape_ptr->At(3) - + 2 * pooling_conf.pad_w() + output_shape_dim_vec.push_back((input_shape.At(3) + 2 * pooling_conf.pad_w() - pooling_conf.kernel_size_w()) / pooling_conf.stride_w() + 1); - *output_shape_ptr = Shape(output_shape_dim_vec); - Shape* data_tmp_shape_ptr = GetShapePtr4BnInOp(SoleDtbn()); - *data_tmp_shape_ptr = Shape(output_shape_dim_vec); + output_blob_desc->mut_shape() = Shape(output_shape_dim_vec); + BlobDesc* data_tmp_blob_desc = GetBlobDesc4BnInOp(SoleDtbn()); + data_tmp_blob_desc->mut_shape() = Shape(output_shape_dim_vec); } REGISTER_OP(OperatorConf::kPoolingConf, PoolingOp); diff --git a/oneflow/core/operator/pooling_op.h b/oneflow/core/operator/pooling_op.h index b6dbcfd8b8..ace71984aa 100644 --- a/oneflow/core/operator/pooling_op.h +++ b/oneflow/core/operator/pooling_op.h @@ -16,8 +16,8 @@ class PoolingOp final : public UserOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/pooling_op_test.cpp b/oneflow/core/operator/pooling_op_test.cpp index f525832b8e..9cc9abbf13 100644 --- a/oneflow/core/operator/pooling_op_test.cpp +++ b/oneflow/core/operator/pooling_op_test.cpp @@ -30,7 +30,7 @@ TEST(PoolingOp, pool_100x64x11x11) { return bn2shape_ptr.at(bn); }; // do infer shape - pooling_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); + pooling_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1); // test Shape* output_shape_ptr = bn2shape_ptr.at(pooling_op->SoleObn()); Shape* data_tmp_shape_ptr = bn2shape_ptr.at(pooling_op->SoleDtbn()); diff --git a/oneflow/core/operator/record_op.h b/oneflow/core/operator/record_op.h index 06e5fe301d..d02a532706 100644 --- a/oneflow/core/operator/record_op.h +++ b/oneflow/core/operator/record_op.h @@ -15,8 +15,8 @@ class RecordOp final : public SysOperator { const PbMessage& GetSpecialConf() const override; bool IsRecordOp() const override { return true; } - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override {} diff --git a/oneflow/core/operator/relu_op.cpp b/oneflow/core/operator/relu_op.cpp index ec35d5fccb..8fc5b6b200 100644 --- a/oneflow/core/operator/relu_op.cpp +++ b/oneflow/core/operator/relu_op.cpp @@ -14,12 +14,10 @@ const PbMessage& ReluOp::GetSpecialConf() const { return op_conf().relu_conf(); } -void ReluOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void ReluOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* output_shape_ptr = GetShapePtr4BnInOp(SoleObn()); - Shape* input_shape_ptr = GetShapePtr4BnInOp(SoleIbn()); - *output_shape_ptr = *input_shape_ptr; + *GetBlobDesc4BnInOp(SoleObn()) = *GetBlobDesc4BnInOp(SoleIbn()); } REGISTER_OP(OperatorConf::kReluConf, ReluOp); diff --git a/oneflow/core/operator/relu_op.h b/oneflow/core/operator/relu_op.h index 631213372f..ee776f0816 100644 --- a/oneflow/core/operator/relu_op.h +++ b/oneflow/core/operator/relu_op.h @@ -15,8 +15,8 @@ class ReluOp final : public UserOperator { const PbMessage& GetSpecialConf() const override; bool IsElemWise() const override { return true; } - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/relu_op_test.cpp b/oneflow/core/operator/relu_op_test.cpp index 1753a0c2e9..74c12bad2a 100644 --- a/oneflow/core/operator/relu_op_test.cpp +++ b/oneflow/core/operator/relu_op_test.cpp @@ -17,7 +17,7 @@ TEST(ReluOp, relu_3x5x4) { return bn2shape_ptr.at(bn); }; // do infer shape - relu_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); + relu_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1); // test Shape* input_shape_ptr = bn2shape_ptr.at(relu_op->SoleIbn()); Shape* output_shape_ptr = bn2shape_ptr.at(relu_op->SoleObn()); diff --git a/oneflow/core/operator/rmsprop_model_update_op.cpp b/oneflow/core/operator/rmsprop_model_update_op.cpp index c517b113c2..2da4a892d1 100644 --- a/oneflow/core/operator/rmsprop_model_update_op.cpp +++ b/oneflow/core/operator/rmsprop_model_update_op.cpp @@ -15,11 +15,10 @@ const PbMessage& RMSPropModelUpdateOp::GetSpecialConf() const { return op_conf().rmsprop_mdupdt_conf(); } -void RMSPropModelUpdateOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void RMSPropModelUpdateOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - Shape* input_shape_ptr = GetShapePtr4BnInOp("model_diffs"); - *GetShapePtr4BnInOp("mean_square") = *input_shape_ptr; + TODO(); } REGISTER_OP(OperatorConf::kRmspropMdupdtConf, RMSPropModelUpdateOp); diff --git a/oneflow/core/operator/rmsprop_model_update_op.h b/oneflow/core/operator/rmsprop_model_update_op.h index c8ee7dea14..72fbbc32e5 100644 --- a/oneflow/core/operator/rmsprop_model_update_op.h +++ b/oneflow/core/operator/rmsprop_model_update_op.h @@ -13,8 +13,8 @@ class RMSPropModelUpdateOp final : public ModelUpdtOp { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/softmax_loss_op.cpp b/oneflow/core/operator/softmax_loss_op.cpp index 71f47282f7..19369aab9d 100644 --- a/oneflow/core/operator/softmax_loss_op.cpp +++ b/oneflow/core/operator/softmax_loss_op.cpp @@ -17,16 +17,16 @@ const PbMessage& SoftmaxLossOp::GetSpecialConf() const { return op_conf().softmax_loss_conf(); } -void SoftmaxLossOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void SoftmaxLossOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { const std::vector in_dim_vec = - GetShapePtr4BnInOp("prediction")->dim_vec(); + GetBlobDesc4BnInOp("prediction")->shape().dim_vec(); CHECK_EQ(in_dim_vec.size(), 2); - CHECK_EQ(*GetShapePtr4BnInOp("label"), Shape({in_dim_vec[0]})); - *GetShapePtr4BnInOp(SoleObn()) = Shape({1}); - *GetShapePtr4BnInOp("prob") = Shape(in_dim_vec); - *GetShapePtr4BnInOp("tmp_1D") = Shape({in_dim_vec[0]}); + CHECK_EQ(GetBlobDesc4BnInOp("label")->shape(), Shape({in_dim_vec[0]})); + GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape({1}); + GetBlobDesc4BnInOp("prob")->mut_shape() = Shape(in_dim_vec); + GetBlobDesc4BnInOp("tmp_1D")->mut_shape() = Shape({in_dim_vec[0]}); } REGISTER_OP(OperatorConf::kSoftmaxLossConf, SoftmaxLossOp); diff --git a/oneflow/core/operator/softmax_loss_op.h b/oneflow/core/operator/softmax_loss_op.h index f0b2b966ac..cca9bea9c9 100644 --- a/oneflow/core/operator/softmax_loss_op.h +++ b/oneflow/core/operator/softmax_loss_op.h @@ -15,8 +15,8 @@ class SoftmaxLossOp final : public UserOperator { const PbMessage& GetSpecialConf() const override; bool IsLossOp() const override { return true; } - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/softmax_loss_op_test.cpp b/oneflow/core/operator/softmax_loss_op_test.cpp index eb7048024f..558465e8a5 100644 --- a/oneflow/core/operator/softmax_loss_op_test.cpp +++ b/oneflow/core/operator/softmax_loss_op_test.cpp @@ -20,7 +20,7 @@ TEST(SoftmaxLossOp, softmax_loss_3x5) { return bn2shape_ptr.at(bn); }; // infershape - softmax_loss_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); + softmax_loss_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1); // test ASSERT_EQ(*fp("loss"), Shape({1})); ASSERT_EQ(*fp("prob"), Shape({3, 5})); diff --git a/oneflow/core/operator/softmax_op.cpp b/oneflow/core/operator/softmax_op.cpp index 63bceb28a4..e83723bb82 100644 --- a/oneflow/core/operator/softmax_op.cpp +++ b/oneflow/core/operator/softmax_op.cpp @@ -15,13 +15,14 @@ const PbMessage& SoftmaxOp::GetSpecialConf() const { return op_conf().softmax_conf(); } -void SoftmaxOp::InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, +void SoftmaxOp::InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { - std::vector vec = GetShapePtr4BnInOp(SoleIbn())->dim_vec(); + const std::vector& vec = + GetBlobDesc4BnInOp(SoleIbn())->shape().dim_vec(); CHECK_EQ(vec.size(), 2); - *GetShapePtr4BnInOp(SoleObn()) = Shape(vec); - *GetShapePtr4BnInOp(SoleDtbn()) = Shape({vec[0]}); + GetBlobDesc4BnInOp(SoleObn())->mut_shape() = Shape(vec); + GetBlobDesc4BnInOp(SoleDtbn())->mut_shape() = Shape({vec[0]}); } REGISTER_OP(OperatorConf::kSoftmaxConf, SoftmaxOp); diff --git a/oneflow/core/operator/softmax_op.h b/oneflow/core/operator/softmax_op.h index 3aba95c97b..ea6d6bf9d0 100644 --- a/oneflow/core/operator/softmax_op.h +++ b/oneflow/core/operator/softmax_op.h @@ -14,8 +14,8 @@ class SoftmaxOp final : public UserOperator { void InitFromOpConf(const OperatorConf& op_conf) override; const PbMessage& GetSpecialConf() const override; - void InferShape4FwBlobs( - std::function GetShapePtr4BnInOp, + void InferBlobDesc4FwBlobs( + std::function GetBlobDesc4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const override; diff --git a/oneflow/core/operator/softmax_op_test.cpp b/oneflow/core/operator/softmax_op_test.cpp index 7cab6b8a8c..f0b8a3ce06 100644 --- a/oneflow/core/operator/softmax_op_test.cpp +++ b/oneflow/core/operator/softmax_op_test.cpp @@ -17,7 +17,7 @@ TEST(SoftmaxOp, softmax_3x5) { return bn2shape_ptr.at(bn); }; // infershape - softmax_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); + softmax_op->InferBlobDesc4FwBlobs(fp, kDataParallel, 0, 1); // test Shape* output_shape_ptr = fp(softmax_op->SoleObn()); Shape* tmp_shape_ptr = fp(softmax_op->SoleDtbn()); diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h index 0979098cc3..f1d7b880ce 100644 --- a/oneflow/core/register/blob.h +++ b/oneflow/core/register/blob.h @@ -6,7 +6,7 @@ namespace oneflow { -class Blob { +class Blob final { public: OF_DISALLOW_COPY_AND_MOVE(Blob); Blob(void* dptr, const Shape* shape) : dptr_(dptr), shape_(shape) {} diff --git a/oneflow/core/register/blob_desc.h b/oneflow/core/register/blob_desc.h new file mode 100644 index 0000000000..2258ef870a --- /dev/null +++ b/oneflow/core/register/blob_desc.h @@ -0,0 +1,30 @@ +#ifndef ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ +#define ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ + +#include "oneflow/core/common/shape.h" +#include "oneflow/core/register/blob_desc.pb.h" + +namespace oneflow { + +class BlobDesc final { + public: + // OF_DISALLOW_COPY_AND_MOVE(BlobDesc); + BlobDesc() = default; + ~BlobDesc() = default; + + BlobDesc(const BlobDescProto& proto) { shape_ = Shape(proto.shape()); } + + const Shape& shape() const { return shape_; } + Shape& mut_shape() { return shape_; } + + void ToProto(BlobDescProto* proto) const { + shape_.ToProto(proto->mutable_shape()); + } + + private: + Shape shape_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_REGISTER_BLOB_DESC_H_ diff --git a/oneflow/core/register/blob_desc.proto b/oneflow/core/register/blob_desc.proto new file mode 100644 index 0000000000..893f18ed47 --- /dev/null +++ b/oneflow/core/register/blob_desc.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package oneflow; + +import "oneflow/core/common/shape.proto"; + +message BlobDescProto { + ShapeProto shape = 1; +} diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index 4ac2aaa80b..08030d9007 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -40,57 +40,57 @@ void RegstDesc::AddConsumer(const TaskNode* new_consumer) { } void RegstDesc::CopyLbnFrom(const RegstDesc* rhs) { - lbn2shape_.clear(); - for (const auto& pair : rhs->lbn2shape_) { + CHECK(lbn2blob_desc_.empty()); + for (const auto& pair : rhs->lbn2blob_desc_) { const std::string& lbn = pair.first; - auto shape = of_make_unique(); - CHECK(lbn2shape_.emplace(lbn, std::move(shape)).second); + CHECK(lbn2blob_desc_.emplace(lbn, of_make_unique()).second); } } -void RegstDesc::CopyShapeFrom(const RegstDesc* rhs) { - for (const auto& pair : lbn2shape_) { +void RegstDesc::CopyBlobDescFrom(const RegstDesc* rhs) { + for (const auto& pair : lbn2blob_desc_) { const std::string& lbn = pair.first; - *(lbn2shape_.at(lbn)) = rhs->GetShape(lbn); + *(lbn2blob_desc_.at(lbn)) = rhs->GetBlobDesc(lbn); } } void RegstDesc::EnrollLbn(const std::string& lbn) { - std::unique_ptr ptr(new Shape); - CHECK(lbn2shape_.emplace(lbn, std::move(ptr)).second) << lbn; + CHECK(lbn2blob_desc_.emplace(lbn, of_make_unique()).second) << lbn; } -const Shape& RegstDesc::GetShape(const std::string& lbn) const { - return *(lbn2shape_.at(lbn)); +const BlobDesc& RegstDesc::GetBlobDesc(const std::string& lbn) const { + return *(lbn2blob_desc_.at(lbn)); } -Shape* RegstDesc::GetMutShapePtr(const std::string& lbn) { - return lbn2shape_.at(lbn).get(); +BlobDesc* RegstDesc::GetMutBlobDesc(const std::string& lbn) { + return lbn2blob_desc_.at(lbn).get(); } void RegstDesc::ForEachLbn(std::function func) const { - for (const auto& p : lbn2shape_) { func(p.first); } + for (const auto& p : lbn2blob_desc_) { func(p.first); } } void RegstDesc::EraseZeroSizeBlob() { - EraseIf>( - &lbn2shape_, - [](HashMap>::iterator it) { - return it->second->elem_cnt() == 0; + EraseIf>( + &lbn2blob_desc_, + [](HashMap>::iterator it) { + return it->second->shape().elem_cnt() == 0; }); } int64_t RegstDesc::CompElemCntOfAllBlob() const { int64_t sum = 0; - for (const auto& pair : lbn2shape_) { sum += pair.second->elem_cnt(); } + for (const auto& pair : lbn2blob_desc_) { + sum += pair.second->shape().elem_cnt(); + } return sum; } std::string RegstDesc::DebugStr() const { std::stringstream ss; ss << "{"; - for (const auto& pair : lbn2shape_) { - ss << "{" << pair.first << ":" << pair.second->DebugStr() << "}"; + for (const auto& pair : lbn2blob_desc_) { + ss << "{" << pair.first << ":" << pair.second->shape().DebugStr() << "}"; } ss << "}"; return ss.str(); @@ -104,10 +104,10 @@ void RegstDesc::ToProto(RegstDescProto* ret) const { ret->add_consumer_task_id(consumer->task_id()); } } - for (const auto& pair : lbn2shape_) { - PbMapPair pb_pair(pair.first); + for (const auto& pair : lbn2blob_desc_) { + PbMapPair pb_pair(pair.first); pair.second->ToProto(&(pb_pair.second)); - ret->mutable_lbn2shape()->insert(pb_pair); + ret->mutable_lbn2blob_desc()->insert(pb_pair); } ret->set_register_num(register_num_); *(ret->mutable_mem_case()) = InferMemCase(); @@ -140,4 +140,9 @@ MemoryCase RegstDesc::InferMemCase() const { return mem_case; } +BlobDesc RegstDesc::CompPackedBlobDesc() const { + BlobDesc packed_blob_desc; + packed_blob_desc.mut_shape() = Shape({CompElemCntOfAllBlob()}); +} + } // namespace oneflow diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index c9c9038a35..d69e409485 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -1,15 +1,11 @@ #ifndef ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_ #define ONEFLOW_CORE_REGISTER_REGISTER_DESC_H_ -#include "oneflow/core/common/shape.h" -#include "oneflow/core/common/util.h" +#include "oneflow/core/register/blob_desc.h" #include "oneflow/core/register/register_desc.pb.h" namespace oneflow { -// Regst : Register -// Contig : Contiguous - class TaskNode; class RegstDesc final { @@ -27,28 +23,29 @@ class RegstDesc final { const HashSet& consumers() const { return consumers_; } void AddConsumer(const TaskNode*); - // Lbn and Shape + // Lbn and BlobDesc void CopyLbnFrom(const RegstDesc*); - void CopyShapeFrom(const RegstDesc*); + void CopyBlobDescFrom(const RegstDesc*); void EnrollLbn(const std::string& lbn); - const Shape& GetShape(const std::string& lbn) const; - Shape* GetMutShapePtr(const std::string& lbn); + const BlobDesc& GetBlobDesc(const std::string& lbn) const; + BlobDesc* GetMutBlobDesc(const std::string& lbn); void ForEachLbn(std::function func) const; - size_t NumOfLbn() const { return lbn2shape_.size(); } + size_t NumOfLbn() const { return lbn2blob_desc_.size(); } // void EraseZeroSizeBlob(); - int64_t CompElemCntOfAllBlob() const; std::string DebugStr() const; void ToProto(RegstDescProto*) const; MemoryCase InferMemCase() const; + BlobDesc CompPackedBlobDesc() const; private: + int64_t CompElemCntOfAllBlob() const; int64_t regst_desc_id_; const TaskNode* producer_; HashSet consumers_; - HashMap> lbn2shape_; + HashMap> lbn2blob_desc_; int64_t register_num_; }; diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto index 58d548e069..54b3d0e814 100644 --- a/oneflow/core/register/register_desc.proto +++ b/oneflow/core/register/register_desc.proto @@ -1,14 +1,14 @@ syntax = "proto3"; package oneflow; -import "oneflow/core/common/shape.proto"; +import "oneflow/core/register/blob_desc.proto"; import "oneflow/core/memory/memory_case.proto"; message RegstDescProto { int64 regst_desc_id = 1; int64 producer_task_id = 2; repeated int64 consumer_task_id = 3; - map lbn2shape = 4; + map lbn2blob_desc = 4; int64 register_num = 5; MemoryCase mem_case = 6; } diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp index 881faa62b0..1bfc16d8a7 100644 --- a/oneflow/core/register/register_manager.cpp +++ b/oneflow/core/register/register_manager.cpp @@ -19,12 +19,12 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, } int64_t elem_cnt = 0; std::vector lbns; - lbns.reserve(regst_desc_proto.lbn2shape().size()); - for (const auto& pair : regst_desc_proto.lbn2shape()) { - const Shape* shape_ptr = - runtime_regst_desc->GetShapePtrFromLbn(pair.first); + lbns.reserve(regst_desc_proto.lbn2blob_desc().size()); + for (const auto& pair : regst_desc_proto.lbn2blob_desc()) { + const Shape& shape_ptr = + runtime_regst_desc->GetBlobDescFromLbn(pair.first)->shape(); lbns.push_back(pair.first); - elem_cnt += shape_ptr->elem_cnt(); + elem_cnt += shape_ptr.elem_cnt(); } std::sort(lbns.begin(), lbns.end()); std::pair> allocation = @@ -33,7 +33,8 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, int64_t blob_idx = 0; for (const std::string& lbn : lbns) { - const Shape* shape_ptr = runtime_regst_desc->GetShapePtrFromLbn(lbn); + const Shape* shape_ptr = + &(runtime_regst_desc->GetBlobDescFromLbn(lbn)->shape()); auto blob_ptr = of_make_unique(allocation.first + blob_idx, shape_ptr); CHECK(regst->lbn2blob_.emplace(lbn, std::move(blob_ptr)).second); diff --git a/oneflow/core/register/runtime_register_desc.cpp b/oneflow/core/register/runtime_register_desc.cpp index e3d78a817b..7d86ddce39 100644 --- a/oneflow/core/register/runtime_register_desc.cpp +++ b/oneflow/core/register/runtime_register_desc.cpp @@ -15,8 +15,9 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& regst_desc_proto) { consumers_actor_id_.push_back(IDMgr::Singleton()->ActorId4TaskId(task_id)); } - for (const auto& pair : regst_desc_proto.lbn2shape()) { - CHECK(lbn2shape_.emplace(pair.first, of_make_unique(pair.second)) + for (const auto& pair : regst_desc_proto.lbn2blob_desc()) { + CHECK(lbn2blob_desc_ + .emplace(pair.first, of_make_unique(pair.second)) .second); } mem_case_ = regst_desc_proto.mem_case(); diff --git a/oneflow/core/register/runtime_register_desc.h b/oneflow/core/register/runtime_register_desc.h index 31258039c7..14dc491a26 100644 --- a/oneflow/core/register/runtime_register_desc.h +++ b/oneflow/core/register/runtime_register_desc.h @@ -1,8 +1,8 @@ #ifndef ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_ #define ONEFLOW_CORE_REGISTER_RUNTIME_REGISTER_DESC_H_ -#include "oneflow/core/common/shape.h" #include "oneflow/core/memory/memory_case.pb.h" +#include "oneflow/core/register/blob_desc.h" #include "oneflow/core/register/register_desc.pb.h" namespace oneflow { @@ -23,15 +23,15 @@ class RtRegstDesc { int64_t register_num() const { return register_num_; } const MemoryCase& mem_case() const { return mem_case_; } - const Shape* GetShapePtrFromLbn(const std::string& lbn) const { - return lbn2shape_.at(lbn).get(); + const BlobDesc* GetBlobDescFromLbn(const std::string& lbn) const { + return lbn2blob_desc_.at(lbn).get(); } private: int64_t regst_desc_id_; int64_t producer_actor_id_; std::vector consumers_actor_id_; - std::unordered_map> lbn2shape_; + std::unordered_map> lbn2blob_desc_; int64_t register_num_; MemoryCase mem_case_; }; -- GitLab