diff --git a/oneflow/blob/blob_descriptor.h b/oneflow/blob/blob_descriptor.h index 2ca91830f472aa3e1b43293e0ab3ae78df02378a..e6e5011a7f72afb487f6b83c582961f25a52e125 100644 --- a/oneflow/blob/blob_descriptor.h +++ b/oneflow/blob/blob_descriptor.h @@ -26,7 +26,6 @@ class BlobDescriptor { Shape& mutable_shape() { return shape_; } MemoryContext& mutable_memory_context() { return memory_context_; } FloatType& mutable_float_type() { return float_type_; } - private: Shape shape_; diff --git a/oneflow/common/proto_io.cpp b/oneflow/common/proto_io.cpp index ab1b781468bdf73cf9058dea824b31605bb94c94..ac9e400789d9a3dfde8fa12bbc912823220cee80 100644 --- a/oneflow/common/proto_io.cpp +++ b/oneflow/common/proto_io.cpp @@ -18,6 +18,9 @@ using google::protobuf::io::ZeroCopyInputStream; using google::protobuf::io::CodedInputStream; using google::protobuf::io::ZeroCopyOutputStream; using google::protobuf::io::CodedOutputStream; +using google::protobuf::Descriptor; +using google::protobuf::Reflection; +using google::protobuf::FieldDescriptor; // string void ParseProtoFromString(const std::string& str, PbMessage* proto) { @@ -43,4 +46,13 @@ void PrintProtoToTextFile(const PbMessage& proto, close(fd); } +std::string GetStringValueFromPbMessage(const PbMessage& msg, + const std::string& key) { + const Descriptor* d = msg.GetDescriptor(); + const FieldDescriptor* fd = d->FindFieldByName(key); + CHECK_NOTNULL(fd); + const Reflection* r = msg.GetReflection(); + return r->GetString(msg, fd); +} + } // namespace oneflow diff --git a/oneflow/common/proto_io.h b/oneflow/common/proto_io.h index 0a5e1d6333eca64e505366bc1e66e7032a923294..328d5445a48b7aa4dd56acd01559b6545b5c6e55 100644 --- a/oneflow/common/proto_io.h +++ b/oneflow/common/proto_io.h @@ -19,6 +19,9 @@ void ParseProtoFromTextFile(const std::string& file_path, void PrintProtoToTextFile(const PbMessage& proto, const std::string& file_path); +std::string GetStringValueFromPbMessage(const PbMessage& msg, + const std::string& key); + } // namespace caffe #endif // ONEFLOW_PROTO_IO_H_ diff --git a/oneflow/dag/dag.cpp b/oneflow/dag/dag.cpp index f0bc5835a6ebfc8a004aae64d3be0c7c42e6a036..34fcf1faa096b942bc0c41b8b289d1b86df51a95 100644 --- a/oneflow/dag/dag.cpp +++ b/oneflow/dag/dag.cpp @@ -42,4 +42,15 @@ bool Dag::DagIterator::operator != (const Dag::DagIterator& rhs) const { } } +void Dag::ConnectStartAndStop() { + for (DagNode* node : data_op_node_vec_) { + if (node->predecessors().empty()) { + node->AddPredecessor(&start_node_); + } + if (node->successors().empty()) { + stop_node_.AddPredecessor(node); + } + } +} + } // namespace oneflow diff --git a/oneflow/dag/dag.h b/oneflow/dag/dag.h index acf2c39a90a39419e62b9ca1566d4976dbc435a7..8cd264de77f9ae86c8aacb660298353c14305bff 100644 --- a/oneflow/dag/dag.h +++ b/oneflow/dag/dag.h @@ -92,11 +92,33 @@ class Dag { return ret; } + protected: + void ConnectStartAndStop(); + + void RegisterDataNode(std::unique_ptr new_node) { + data_op_node_vec_.push_back(new_node.get()); + data_node_vec_.push_back(std::move(new_node)); + } + void RegisterOpNode(std::unique_ptr new_node) { + data_op_node_vec_.push_back(new_node.get()); + op_node_vec_.push_back(std::move(new_node)); + } + + const std::vector>& op_node_vec() const { + return op_node_vec_; + } + private: std::string dag_name_; DagNode start_node_; DagNode stop_node_; + // In future we can implement a Iterator to replace the data_op_node_vec_ + // which is redundancy + std::vector data_op_node_vec_; + std::vector> data_node_vec_; + std::vector> op_node_vec_; + }; } // namespace oneflow diff --git a/oneflow/dag/dag_node.h b/oneflow/dag/dag_node.h index 206a8072c0cd526ce72b3c94ea0b13387ad1abf1..be3e4d2534435c09428031785f0c6355c9cfced1 100644 --- a/oneflow/dag/dag_node.h +++ b/oneflow/dag/dag_node.h @@ -39,5 +39,36 @@ class DagNode { }; +class DataNode : public DagNode { + public: + DISALLOW_COPY_AND_MOVE(DataNode); + virtual ~DataNode() = default; + + protected: + DataNode() = default; + void Init() { + DagNode::Init(); + } + + private: + +}; + +class OpNode : public DagNode { + public: + DISALLOW_COPY_AND_MOVE(OpNode); + + virtual ~OpNode() = default; + + protected: + OpNode() = default; + void Init() { + DagNode::Init(); + } + + private: + +}; + } // namespace oneflow #endif // ONEFLOW_DAG_DAG_NODE_H_ diff --git a/oneflow/dag/dag_node_test.cpp b/oneflow/dag/dag_node_test.cpp index 0d97b197443c4269f1a0c2f28808313503463040..eb82f9d6d43994bde3fc821b11ee32e1fdcd249f 100644 --- a/oneflow/dag/dag_node_test.cpp +++ b/oneflow/dag/dag_node_test.cpp @@ -2,8 +2,7 @@ #include #include "gtest/gtest.h" #include "common/util.h" -#include "dag/data_node.h" -#include "dag/op_node.h" +#include "dag/dag_node.h" namespace oneflow { diff --git a/oneflow/dag/data_node.h b/oneflow/dag/data_node.h deleted file mode 100644 index ba75a80ef1f711b7cf5edebb514ed02353f0d4d1..0000000000000000000000000000000000000000 --- a/oneflow/dag/data_node.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef ONEFLOW_DAG_DATA_NODE_H_ -#define ONEFLOW_DAG_DATA_NODE_H_ - -#include "dag/dag_node.h" - -namespace oneflow { - -class DataNode : public DagNode { - public: - DISALLOW_COPY_AND_MOVE(DataNode); - virtual ~DataNode() = default; - - protected: - DataNode() = default; - void Init() { - DagNode::Init(); - } - - private: - -}; - -} // namespace oneflow - -#endif // ONEFLOW_DAG_DATA_NODE_H_ diff --git a/oneflow/dag/logical_dag.cpp b/oneflow/dag/logical_dag.cpp index e9ff6478e13cdf7ba060bac4ec85d7539b01972f..3678f5f4d25a8c87aaa264e6b84db79f3d7205e8 100644 --- a/oneflow/dag/logical_dag.cpp +++ b/oneflow/dag/logical_dag.cpp @@ -1,4 +1,5 @@ #include "dag/logical_dag.h" +#include "glog/logging.h" #include "layer/layer_desc_factory.h" namespace oneflow { @@ -8,27 +9,75 @@ void LogicalDag::Init(const std::string& dag_name, const Strategy& strategy_conf) { Dag::Init(dag_name); BuildDagStruct(dl_net_conf); - FillNodeWithPlacement(strategy_conf); + FillNodeWithParallelConf(strategy_conf); +} + +// BlobNameInDag = LayerName/BlobNameInLayer +// BlobNameInDagIf means Blob is the input of layer +// BlobNameInDagOf means Blob is the output of layer +static std::string BlobNameInDag2BlobNameInLayer( + const std::string& blob_name_in_dag) { + size_t slash_pos = blob_name_in_dag.find('/'); + CHECK(slash_pos != std::string::npos); + return blob_name_in_dag.substr(slash_pos + 1); } void LogicalDag::BuildDagStruct(const DLNetConf& dl_net_conf) { + // This function only execute few times, so it is ok to declare it + std::unordered_map blob_name_indag_of2ptr; + // Process Layer for (int layer_i = 0; layer_i < dl_net_conf.layer_conf_size(); ++layer_i) { const LayerConf& cur_layer_conf = dl_net_conf.layer_conf(layer_i); // Construct op node LogicalOpNode* cur_op_node = NewLogicalOpNode(); cur_op_node->mutable_layer_desc() = LayerDescFactory::singleton().ConstructLayerDesc(cur_layer_conf); - // Construct input data node - for (const std::string& input_blob_name + // Connect input data node + for (const std::string& blob_name_in_dag_if : cur_op_node->layer_desc().data_blob_desc_set().input_blob_names()) { - input_blob_name.find('/'); + std::string blob_name_in_layer = + BlobNameInDag2BlobNameInLayer(blob_name_in_dag_if); + std::string blob_name_indag_of = + GetStringValueFromPbMessage(cur_layer_conf, blob_name_in_layer); + auto data_node_it = blob_name_indag_of2ptr.find(blob_name_indag_of); + CHECK(data_node_it != blob_name_indag_of2ptr.end()); + cur_op_node->AddPredecessor(data_node_it->second); + } + // Construct and connect output data node + for (const std::string& blob_name_indag_of + : cur_op_node->layer_desc().data_blob_desc_set().output_blob_names()) { + LogicalDataNode* data_node = NewLogicalDataNode(); + bool insert_success = + blob_name_indag_of2ptr.emplace(blob_name_indag_of, data_node).second; + CHECK_EQ(insert_success, true); + data_node->AddPredecessor(cur_op_node); } - // Construct output data node } + blob_name_indag_of2ptr.clear(); + // Post Processing + ConnectStartAndStop(); } -void FillNodeWithPlacement(const Strategy& strategy_conf) { - // TODO +void LogicalDag::FillNodeWithParallelConf(const Strategy& strategy_conf) { + // This function only execute few times, so it is ok to declare it + std::unordered_map layer_name2op_node; + for (const std::unique_ptr& op_node : op_node_vec()) { + auto logical_op_node_ptr = dynamic_cast (op_node.get()); + CHECK_NOTNULL(logical_op_node_ptr); + std::string layer_name = logical_op_node_ptr->layer_desc().layer_name(); + bool emplace_success = + layer_name2op_node.emplace(layer_name, logical_op_node_ptr).second; + CHECK_EQ(emplace_success, true); + } + for (int gid = 0; gid < strategy_conf.placement_group_vec_size(); ++gid) { + const PlacementGroup& cur_group = strategy_conf.placement_group_vec(gid); + for (int li = 0; li < cur_group.layer_name_vec_size(); ++li) { + const std::string& layer_name = cur_group.layer_name_vec(li); + auto it = layer_name2op_node.find(layer_name); + CHECK(it != layer_name2op_node.end()); + it->second->mutable_parallel_conf() = cur_group.parallel_conf(); + } + } } } // namespace oneflow diff --git a/oneflow/dag/logical_dag.h b/oneflow/dag/logical_dag.h index eb9a742896bb2326d8b4d64b36c6f3c9b786d871..6326ca4b9f5b33d9faa7efffbd41aa0b092521b2 100644 --- a/oneflow/dag/logical_dag.h +++ b/oneflow/dag/logical_dag.h @@ -3,13 +3,58 @@ #include #include "dag/dag.h" -#include "dag/logical_data_node.h" -#include "dag/logical_op_node.h" +#include "layer/base_layer_desc.h" #include "job/dlnet_conf.pb.h" #include "job/strategy.pb.h" namespace oneflow { +class LogicalDataNode : public DataNode { + public: + DISALLOW_COPY_AND_MOVE(LogicalDataNode); + LogicalDataNode() = default; + ~LogicalDataNode() = default; + + void Init() { + DataNode::Init(); + // struct style + } + + private: + +}; + +class LogicalOpNode : public OpNode { + public: + DISALLOW_COPY_AND_MOVE(LogicalOpNode); + LogicalOpNode() = default; + ~LogicalOpNode() = default; + + void Init() { + OpNode::Init(); + // struct style + } + + const BaseLayerDesc& layer_desc() const { + return *(layer_desc_.get()); + } + const ParallelConf& parallel_conf() const { + return parallel_conf_; + } + + std::unique_ptr& mutable_layer_desc() { + return layer_desc_; + } + ParallelConf& mutable_parallel_conf() { + return parallel_conf_; + } + + private: + std::unique_ptr layer_desc_; + ParallelConf parallel_conf_; + +}; + class LogicalDag : public Dag { public: DISALLOW_COPY_AND_MOVE(LogicalDag); @@ -22,25 +67,22 @@ class LogicalDag : public Dag { private: void BuildDagStruct(const DLNetConf& dl_net_conf); - void FillNodeWithPlacement(const Strategy& strategy_conf); + void FillNodeWithParallelConf(const Strategy& strategy_conf); LogicalDataNode* NewLogicalDataNode() { - std::unique_ptr new_node(new LogicalDataNode); - new_node->Init(); - logical_data_node_vec_.push_back(std::move(new_node)); - return logical_data_node_vec_.back().get(); + LogicalDataNode* ret_ptr = new LogicalDataNode; + ret_ptr->Init(); + RegisterDataNode(std::unique_ptr (ret_ptr)); + return ret_ptr; } LogicalOpNode* NewLogicalOpNode() { - std::unique_ptr new_node(new LogicalOpNode); - new_node->Init(); - logical_op_node_vec_.push_back(std::move(new_node)); - return logical_op_node_vec_.back().get(); + LogicalOpNode* ret_ptr = new LogicalOpNode; + ret_ptr->Init(); + RegisterOpNode(std::unique_ptr (ret_ptr)); + return ret_ptr; } - std::vector> logical_data_node_vec_; - std::vector> logical_op_node_vec_; - }; } // namespace oneflow diff --git a/oneflow/dag/logical_data_node.h b/oneflow/dag/logical_data_node.h deleted file mode 100644 index 382008457038aedd06cb9673622b3a90a201bf00..0000000000000000000000000000000000000000 --- a/oneflow/dag/logical_data_node.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef ONEFLOW_LOGICAL_DATA_NODE_H_ -#define ONEFLOW_LOGICAL_DATA_NODE_H_ - -#include "dag/data_node.h" -#include "blob/blob_descriptor.h" - -namespace oneflow { - -class LogicalDataNode : public DataNode { - public: - DISALLOW_COPY_AND_MOVE(LogicalDataNode); - LogicalDataNode() = default; - ~LogicalDataNode() = default; - - void Init() { - DataNode::Init(); - // struct style - } - - const BlobDescriptor& blob_desc() const { - return blob_desc_; - } - BlobDescriptor& mutable_blob_desc() { - return blob_desc_; - } - - private: - BlobDescriptor blob_desc_; - -}; - -} // namespace oneflow - -#endif // ONEFLOW_LOGICAL_DATA_NODE_H_ diff --git a/oneflow/dag/logical_op_node.h b/oneflow/dag/logical_op_node.h deleted file mode 100644 index 30f8607226373b06e68bb27a98a939aa02cfefdb..0000000000000000000000000000000000000000 --- a/oneflow/dag/logical_op_node.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef ONEFLOW_LOGICAL_OP_NODE_H_ -#define ONEFLOW_LOGICAL_OP_NODE_H_ - -#include "dag/op_node.h" -#include "layer/base_layer_desc.h" -#include "job/strategy.pb.h" - -namespace oneflow { - -class LogicalOpNode : public OpNode { - public: - DISALLOW_COPY_AND_MOVE(LogicalOpNode); - LogicalOpNode() = default; - ~LogicalOpNode() = default; - - void Init() { - OpNode::Init(); - // struct style - } - - const BaseLayerDesc& layer_desc() const { - return *(layer_desc_.get()); - } - const ParallelConf& parallel_conf() const { - return parallel_conf_; - } - - std::unique_ptr& mutable_layer_desc() { - return layer_desc_; - } - ParallelConf& mutable_parallel_conf() { - return parallel_conf_; - } - - private: - std::unique_ptr layer_desc_; - ParallelConf parallel_conf_; - -}; - -} // namespace oneflow - -#endif // ONEFLOW_LOGICAL_OP_NODE_H_ diff --git a/oneflow/dag/op_node.h b/oneflow/dag/op_node.h deleted file mode 100644 index 39b8f27a52b99416b4f9c91c2e14d4cfe9dd742f..0000000000000000000000000000000000000000 --- a/oneflow/dag/op_node.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef ONEFLOW_DAG_OP_NODE_H_ -#define ONEFLOW_DAG_OP_NODE_H_ - -#include "dag/dag_node.h" - -namespace oneflow { - -class OpNode : public DagNode { - public: - DISALLOW_COPY_AND_MOVE(OpNode); - - virtual ~OpNode() = default; - - protected: - OpNode() = default; - void Init() { - DagNode::Init(); - } - - private: - -}; - -} // namespace oneflow - -#endif // ONEFLOW_DAG_OP_NODE_H_