提交 f6cef908 编写于 作者: W willzhang4a58

logical dag done

上级 775d3772
......@@ -26,7 +26,6 @@ class BlobDescriptor {
Shape& mutable_shape() { return shape_; }
MemoryContext& mutable_memory_context() { return memory_context_; }
FloatType& mutable_float_type() { return float_type_; }
private:
Shape shape_;
......
......@@ -18,6 +18,9 @@ using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::ZeroCopyOutputStream;
using google::protobuf::io::CodedOutputStream;
using google::protobuf::Descriptor;
using google::protobuf::Reflection;
using google::protobuf::FieldDescriptor;
// string
void ParseProtoFromString(const std::string& str, PbMessage* proto) {
......@@ -43,4 +46,13 @@ void PrintProtoToTextFile(const PbMessage& proto,
close(fd);
}
std::string GetStringValueFromPbMessage(const PbMessage& msg,
const std::string& key) {
const Descriptor* d = msg.GetDescriptor();
const FieldDescriptor* fd = d->FindFieldByName(key);
CHECK_NOTNULL(fd);
const Reflection* r = msg.GetReflection();
return r->GetString(msg, fd);
}
} // namespace oneflow
......@@ -19,6 +19,9 @@ void ParseProtoFromTextFile(const std::string& file_path,
void PrintProtoToTextFile(const PbMessage& proto,
const std::string& file_path);
std::string GetStringValueFromPbMessage(const PbMessage& msg,
const std::string& key);
} // namespace caffe
#endif // ONEFLOW_PROTO_IO_H_
......@@ -42,4 +42,15 @@ bool Dag::DagIterator::operator != (const Dag::DagIterator& rhs) const {
}
}
void Dag::ConnectStartAndStop() {
for (DagNode* node : data_op_node_vec_) {
if (node->predecessors().empty()) {
node->AddPredecessor(&start_node_);
}
if (node->successors().empty()) {
stop_node_.AddPredecessor(node);
}
}
}
} // namespace oneflow
......@@ -92,11 +92,33 @@ class Dag {
return ret;
}
protected:
void ConnectStartAndStop();
void RegisterDataNode(std::unique_ptr<DataNode> new_node) {
data_op_node_vec_.push_back(new_node.get());
data_node_vec_.push_back(std::move(new_node));
}
void RegisterOpNode(std::unique_ptr<OpNode> new_node) {
data_op_node_vec_.push_back(new_node.get());
op_node_vec_.push_back(std::move(new_node));
}
const std::vector<std::unique_ptr<OpNode>>& op_node_vec() const {
return op_node_vec_;
}
private:
std::string dag_name_;
DagNode start_node_;
DagNode stop_node_;
// In future we can implement a Iterator to replace the data_op_node_vec_
// which is redundancy
std::vector<DagNode*> data_op_node_vec_;
std::vector<std::unique_ptr<DataNode>> data_node_vec_;
std::vector<std::unique_ptr<OpNode>> op_node_vec_;
};
} // namespace oneflow
......
......@@ -39,5 +39,36 @@ class DagNode {
};
class DataNode : public DagNode {
public:
DISALLOW_COPY_AND_MOVE(DataNode);
virtual ~DataNode() = default;
protected:
DataNode() = default;
void Init() {
DagNode::Init();
}
private:
};
class OpNode : public DagNode {
public:
DISALLOW_COPY_AND_MOVE(OpNode);
virtual ~OpNode() = default;
protected:
OpNode() = default;
void Init() {
DagNode::Init();
}
private:
};
} // namespace oneflow
#endif // ONEFLOW_DAG_DAG_NODE_H_
......@@ -2,8 +2,7 @@
#include <vector>
#include "gtest/gtest.h"
#include "common/util.h"
#include "dag/data_node.h"
#include "dag/op_node.h"
#include "dag/dag_node.h"
namespace oneflow {
......
#ifndef ONEFLOW_DAG_DATA_NODE_H_
#define ONEFLOW_DAG_DATA_NODE_H_
#include "dag/dag_node.h"
namespace oneflow {
class DataNode : public DagNode {
public:
DISALLOW_COPY_AND_MOVE(DataNode);
virtual ~DataNode() = default;
protected:
DataNode() = default;
void Init() {
DagNode::Init();
}
private:
};
} // namespace oneflow
#endif // ONEFLOW_DAG_DATA_NODE_H_
#include "dag/logical_dag.h"
#include "glog/logging.h"
#include "layer/layer_desc_factory.h"
namespace oneflow {
......@@ -8,27 +9,75 @@ void LogicalDag::Init(const std::string& dag_name,
const Strategy& strategy_conf) {
Dag::Init(dag_name);
BuildDagStruct(dl_net_conf);
FillNodeWithPlacement(strategy_conf);
FillNodeWithParallelConf(strategy_conf);
}
// BlobNameInDag = LayerName/BlobNameInLayer
// BlobNameInDagIf means Blob is the input of layer
// BlobNameInDagOf means Blob is the output of layer
static std::string BlobNameInDag2BlobNameInLayer(
const std::string& blob_name_in_dag) {
size_t slash_pos = blob_name_in_dag.find('/');
CHECK(slash_pos != std::string::npos);
return blob_name_in_dag.substr(slash_pos + 1);
}
void LogicalDag::BuildDagStruct(const DLNetConf& dl_net_conf) {
// This function only execute few times, so it is ok to declare it
std::unordered_map<std::string, LogicalDataNode*> blob_name_indag_of2ptr;
// Process Layer
for (int layer_i = 0; layer_i < dl_net_conf.layer_conf_size(); ++layer_i) {
const LayerConf& cur_layer_conf = dl_net_conf.layer_conf(layer_i);
// Construct op node
LogicalOpNode* cur_op_node = NewLogicalOpNode();
cur_op_node->mutable_layer_desc() =
LayerDescFactory::singleton().ConstructLayerDesc(cur_layer_conf);
// Construct input data node
for (const std::string& input_blob_name
// Connect input data node
for (const std::string& blob_name_in_dag_if
: cur_op_node->layer_desc().data_blob_desc_set().input_blob_names()) {
input_blob_name.find('/');
std::string blob_name_in_layer =
BlobNameInDag2BlobNameInLayer(blob_name_in_dag_if);
std::string blob_name_indag_of =
GetStringValueFromPbMessage(cur_layer_conf, blob_name_in_layer);
auto data_node_it = blob_name_indag_of2ptr.find(blob_name_indag_of);
CHECK(data_node_it != blob_name_indag_of2ptr.end());
cur_op_node->AddPredecessor(data_node_it->second);
}
// Construct and connect output data node
for (const std::string& blob_name_indag_of
: cur_op_node->layer_desc().data_blob_desc_set().output_blob_names()) {
LogicalDataNode* data_node = NewLogicalDataNode();
bool insert_success =
blob_name_indag_of2ptr.emplace(blob_name_indag_of, data_node).second;
CHECK_EQ(insert_success, true);
data_node->AddPredecessor(cur_op_node);
}
// Construct output data node
}
blob_name_indag_of2ptr.clear();
// Post Processing
ConnectStartAndStop();
}
void FillNodeWithPlacement(const Strategy& strategy_conf) {
// TODO
void LogicalDag::FillNodeWithParallelConf(const Strategy& strategy_conf) {
// This function only execute few times, so it is ok to declare it
std::unordered_map<std::string, LogicalOpNode*> layer_name2op_node;
for (const std::unique_ptr<OpNode>& op_node : op_node_vec()) {
auto logical_op_node_ptr = dynamic_cast<LogicalOpNode*> (op_node.get());
CHECK_NOTNULL(logical_op_node_ptr);
std::string layer_name = logical_op_node_ptr->layer_desc().layer_name();
bool emplace_success =
layer_name2op_node.emplace(layer_name, logical_op_node_ptr).second;
CHECK_EQ(emplace_success, true);
}
for (int gid = 0; gid < strategy_conf.placement_group_vec_size(); ++gid) {
const PlacementGroup& cur_group = strategy_conf.placement_group_vec(gid);
for (int li = 0; li < cur_group.layer_name_vec_size(); ++li) {
const std::string& layer_name = cur_group.layer_name_vec(li);
auto it = layer_name2op_node.find(layer_name);
CHECK(it != layer_name2op_node.end());
it->second->mutable_parallel_conf() = cur_group.parallel_conf();
}
}
}
} // namespace oneflow
......@@ -3,13 +3,58 @@
#include <memory>
#include "dag/dag.h"
#include "dag/logical_data_node.h"
#include "dag/logical_op_node.h"
#include "layer/base_layer_desc.h"
#include "job/dlnet_conf.pb.h"
#include "job/strategy.pb.h"
namespace oneflow {
class LogicalDataNode : public DataNode {
public:
DISALLOW_COPY_AND_MOVE(LogicalDataNode);
LogicalDataNode() = default;
~LogicalDataNode() = default;
void Init() {
DataNode::Init();
// struct style
}
private:
};
class LogicalOpNode : public OpNode {
public:
DISALLOW_COPY_AND_MOVE(LogicalOpNode);
LogicalOpNode() = default;
~LogicalOpNode() = default;
void Init() {
OpNode::Init();
// struct style
}
const BaseLayerDesc& layer_desc() const {
return *(layer_desc_.get());
}
const ParallelConf& parallel_conf() const {
return parallel_conf_;
}
std::unique_ptr<BaseLayerDesc>& mutable_layer_desc() {
return layer_desc_;
}
ParallelConf& mutable_parallel_conf() {
return parallel_conf_;
}
private:
std::unique_ptr<BaseLayerDesc> layer_desc_;
ParallelConf parallel_conf_;
};
class LogicalDag : public Dag {
public:
DISALLOW_COPY_AND_MOVE(LogicalDag);
......@@ -22,25 +67,22 @@ class LogicalDag : public Dag {
private:
void BuildDagStruct(const DLNetConf& dl_net_conf);
void FillNodeWithPlacement(const Strategy& strategy_conf);
void FillNodeWithParallelConf(const Strategy& strategy_conf);
LogicalDataNode* NewLogicalDataNode() {
std::unique_ptr<LogicalDataNode> new_node(new LogicalDataNode);
new_node->Init();
logical_data_node_vec_.push_back(std::move(new_node));
return logical_data_node_vec_.back().get();
LogicalDataNode* ret_ptr = new LogicalDataNode;
ret_ptr->Init();
RegisterDataNode(std::unique_ptr<LogicalDataNode> (ret_ptr));
return ret_ptr;
}
LogicalOpNode* NewLogicalOpNode() {
std::unique_ptr<LogicalOpNode> new_node(new LogicalOpNode);
new_node->Init();
logical_op_node_vec_.push_back(std::move(new_node));
return logical_op_node_vec_.back().get();
LogicalOpNode* ret_ptr = new LogicalOpNode;
ret_ptr->Init();
RegisterOpNode(std::unique_ptr<LogicalOpNode> (ret_ptr));
return ret_ptr;
}
std::vector<std::unique_ptr<LogicalDataNode>> logical_data_node_vec_;
std::vector<std::unique_ptr<LogicalOpNode>> logical_op_node_vec_;
};
} // namespace oneflow
......
#ifndef ONEFLOW_LOGICAL_DATA_NODE_H_
#define ONEFLOW_LOGICAL_DATA_NODE_H_
#include "dag/data_node.h"
#include "blob/blob_descriptor.h"
namespace oneflow {
class LogicalDataNode : public DataNode {
public:
DISALLOW_COPY_AND_MOVE(LogicalDataNode);
LogicalDataNode() = default;
~LogicalDataNode() = default;
void Init() {
DataNode::Init();
// struct style
}
const BlobDescriptor& blob_desc() const {
return blob_desc_;
}
BlobDescriptor& mutable_blob_desc() {
return blob_desc_;
}
private:
BlobDescriptor blob_desc_;
};
} // namespace oneflow
#endif // ONEFLOW_LOGICAL_DATA_NODE_H_
#ifndef ONEFLOW_LOGICAL_OP_NODE_H_
#define ONEFLOW_LOGICAL_OP_NODE_H_
#include "dag/op_node.h"
#include "layer/base_layer_desc.h"
#include "job/strategy.pb.h"
namespace oneflow {
class LogicalOpNode : public OpNode {
public:
DISALLOW_COPY_AND_MOVE(LogicalOpNode);
LogicalOpNode() = default;
~LogicalOpNode() = default;
void Init() {
OpNode::Init();
// struct style
}
const BaseLayerDesc& layer_desc() const {
return *(layer_desc_.get());
}
const ParallelConf& parallel_conf() const {
return parallel_conf_;
}
std::unique_ptr<BaseLayerDesc>& mutable_layer_desc() {
return layer_desc_;
}
ParallelConf& mutable_parallel_conf() {
return parallel_conf_;
}
private:
std::unique_ptr<BaseLayerDesc> layer_desc_;
ParallelConf parallel_conf_;
};
} // namespace oneflow
#endif // ONEFLOW_LOGICAL_OP_NODE_H_
#ifndef ONEFLOW_DAG_OP_NODE_H_
#define ONEFLOW_DAG_OP_NODE_H_
#include "dag/dag_node.h"
namespace oneflow {
class OpNode : public DagNode {
public:
DISALLOW_COPY_AND_MOVE(OpNode);
virtual ~OpNode() = default;
protected:
OpNode() = default;
void Init() {
DagNode::Init();
}
private:
};
} // namespace oneflow
#endif // ONEFLOW_DAG_OP_NODE_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册