Commit 775d3772 authored by willzhang4a58

temp save

Parent 19464a6d
......@@ -13,12 +13,8 @@ class BlobDescriptor {
BlobDescriptor() = default;
~BlobDescriptor() = default;
void init(const Shape& rhs_shape,
const MemoryContext& rhs_memory_context,
FloatType rhs_float_type) {
shape_ = rhs_shape;
memory_context_ = rhs_memory_context;
float_type_ = rhs_float_type;
void Init() {
// struct style
}
const Shape& shape() const { return shape_; }
......@@ -26,6 +22,11 @@ class BlobDescriptor {
size_t byte_size() const {
return shape_.elem_cnt() * GetFloatByteSize(float_type_);
}
Shape& mutable_shape() { return shape_; }
MemoryContext& mutable_memory_context() { return memory_context_; }
FloatType& mutable_float_type() { return float_type_; }
private:
Shape shape_;
......
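The three-argument init(shape, memory_context, float_type) is dropped in favor of an empty Init() plus mutable accessors, so callers now fill a BlobDescriptor field by field ("struct style"). A minimal sketch of the new calling pattern, assuming only the accessors shown above; InitBlobDescriptor is a hypothetical helper, not part of this commit:

// Hypothetical helper illustrating the "struct style" initialization above.
size_t InitBlobDescriptor(BlobDescriptor* blob_desc,
                          const Shape& shape,
                          const MemoryContext& memory_context,
                          FloatType float_type) {
  blob_desc->Init();
  blob_desc->mutable_shape() = shape;                    // copy-assign the shape
  blob_desc->mutable_memory_context() = memory_context;  // copy-assign the memory context
  blob_desc->mutable_float_type() = float_type;          // set the element type
  // byte_size() == shape.elem_cnt() * GetFloatByteSize(float_type)
  return blob_desc->byte_size();
}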
#include "dag/logical_dag.h"
#include "layer/layer_desc_factory.h"
namespace oneflow {
......@@ -10,10 +11,19 @@ void LogicalDag::Init(const std::string& dag_name,
FillNodeWithPlacement(strategy_conf);
}
void BuildDagStruct(const DLNetConf& dl_net_conf) {
void LogicalDag::BuildDagStruct(const DLNetConf& dl_net_conf) {
for (int layer_i = 0; layer_i < dl_net_conf.layer_conf_size(); ++layer_i) {
const LayerConf& cur_layer_conf = dl_net_conf.layer_conf(layer_i);
// TODO
// Construct op node
LogicalOpNode* cur_op_node = NewLogicalOpNode();
cur_op_node->mutable_layer_desc() =
LayerDescFactory::singleton().ConstructLayerDesc(cur_layer_conf);
// Construct input data node
for (const std::string& input_blob_name
: cur_op_node->layer_desc().data_blob_desc_set().input_blob_names()) {
input_blob_name.find('/');
}
// Construct output data node
}
}
......
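The bare input_blob_name.find('/') call above is a stub left by this "temp save" commit; the '/' presumably separates the producing layer's name from the blob's own name. A hedged sketch of what that split might look like; SplitBlobName is a hypothetical helper, not code from the commit:

#include <string>
#include <utility>

// Hypothetical: split "layer_name/blob_name" at the first '/'.
// Assumes every input blob name contains exactly one '/'.
std::pair<std::string, std::string> SplitBlobName(const std::string& input_blob_name) {
  size_t slash_pos = input_blob_name.find('/');
  std::string producer_layer_name = input_blob_name.substr(0, slash_pos);
  std::string blob_name = input_blob_name.substr(slash_pos + 1);
  return {producer_layer_name, blob_name};
}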
#ifndef ONEFLOW_DAG_LOGICAL_DAG_H
#define ONEFLOW_DAG_LOGICAL_DAG_H
#include <memory>
#include "dag/dag.h"
#include "dag/logical_data_node.h"
#include "dag/logical_op_node.h"
......@@ -23,6 +24,23 @@ class LogicalDag : public Dag {
void BuildDagStruct(const DLNetConf& dl_net_conf);
void FillNodeWithPlacement(const Strategy& strategy_conf);
LogicalDataNode* NewLogicalDataNode() {
std::unique_ptr<LogicalDataNode> new_node(new LogicalDataNode);
new_node->Init();
logical_data_node_vec_.push_back(std::move(new_node));
return logical_data_node_vec_.back().get();
}
LogicalOpNode* NewLogicalOpNode() {
std::unique_ptr<LogicalOpNode> new_node(new LogicalOpNode);
new_node->Init();
logical_op_node_vec_.push_back(std::move(new_node));
return logical_op_node_vec_.back().get();
}
std::vector<std::unique_ptr<LogicalDataNode>> logical_data_node_vec_;
std::vector<std::unique_ptr<LogicalOpNode>> logical_op_node_vec_;
};
} // namespace oneflow
......
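NewLogicalDataNode and NewLogicalOpNode use a common ownership idiom: the dag keeps every node alive in a vector of std::unique_ptr and hands back a raw, non-owning pointer for wiring up edges. A small standalone illustration of the same idiom, with made-up Node and Graph types rather than the diff's classes:

#include <memory>
#include <vector>

struct Node {
  void Init() {}
};

class Graph {
 public:
  Node* NewNode() {
    std::unique_ptr<Node> new_node(new Node);
    new_node->Init();
    nodes_.push_back(std::move(new_node));
    return nodes_.back().get();  // non-owning; lifetime is tied to the Graph
  }

 private:
  std::vector<std::unique_ptr<Node>> nodes_;  // owns every node ever created
};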
......@@ -2,6 +2,7 @@
#define ONEFLOW_LOGICAL_DATA_NODE_H_
#include "dag/data_node.h"
#include "blob/blob_descriptor.h"
namespace oneflow {
......@@ -13,9 +14,19 @@ class LogicalDataNode : public DataNode {
void Init() {
DataNode::Init();
// struct style
}
const BlobDescriptor& blob_desc() const {
return blob_desc_;
}
BlobDescriptor& mutable_blob_desc() {
return blob_desc_;
}
private:
BlobDescriptor blob_desc_;
};
} // namespace oneflow
......
......@@ -2,6 +2,8 @@
#define ONEFLOW_LOGICAL_OP_NODE_H_
#include "dag/op_node.h"
#include "layer/base_layer_desc.h"
#include "job/strategy.pb.h"
namespace oneflow {
......@@ -13,9 +15,27 @@ class LogicalOpNode : public OpNode {
void Init() {
OpNode::Init();
// struct style
}
const BaseLayerDesc& layer_desc() const {
return *(layer_desc_.get());
}
const ParallelConf& parallel_conf() const {
return parallel_conf_;
}
std::unique_ptr<BaseLayerDesc>& mutable_layer_desc() {
return layer_desc_;
}
ParallelConf& mutable_parallel_conf() {
return parallel_conf_;
}
private:
std::unique_ptr<BaseLayerDesc> layer_desc_;
ParallelConf parallel_conf_;
};
} // namespace oneflow
......
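Here mutable_layer_desc() exposes the owning unique_ptr by reference, which is what lets BuildDagStruct write cur_op_node->mutable_layer_desc() = LayerDescFactory::singleton().ConstructLayerDesc(cur_layer_conf): the factory's temporary unique_ptr is move-assigned straight into the node. A minimal self-contained illustration of that move-assignment; Desc, MakeDesc, and Holder are placeholders, not the diff's types:

#include <memory>

struct Desc {};

// Placeholder factory returning an rvalue unique_ptr, like ConstructLayerDesc.
std::unique_ptr<Desc> MakeDesc() { return std::unique_ptr<Desc>(new Desc); }

struct Holder {
  std::unique_ptr<Desc>& mutable_desc() { return desc_; }
  std::unique_ptr<Desc> desc_;
};

int main() {
  Holder holder;
  holder.mutable_desc() = MakeDesc();  // move-assignment into the owning member
  return 0;
}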
......@@ -11,30 +11,30 @@
namespace oneflow {
// It is ugly now, maybe we can find a more elegant implementation?
std::shared_ptr<BaseLayerDesc> LayerDescFactory::ConstructLayerDesc(
std::unique_ptr<BaseLayerDesc> LayerDescFactory::ConstructLayerDesc(
const LayerConf& layer_conf) const {
std::shared_ptr<BaseLayerDesc> ret;
std::unique_ptr<BaseLayerDesc> ret;
switch (layer_conf.specified_type_case()) {
case LayerConf::kConvolutionLayerConf: {
ret = std::make_shared<ConvolutionLayerDesc> ();
ret.reset(new ConvolutionLayerDesc);
}
case LayerConf::kInnerProductLayerConf: {
ret = std::make_shared<InnerProductLayerDesc> ();
ret.reset(new InnerProductLayerDesc);
}
case LayerConf::kLoaderLayerConf: {
ret = std::make_shared<LoaderLayerDesc> ();
ret.reset(new LoaderLayerDesc);
}
case LayerConf::kPoolingLayerConf: {
ret = std::make_shared<PoolingLayerDesc> ();
ret.reset(new PoolingLayerDesc);
}
case LayerConf::kReluLayerConf: {
ret = std::make_shared<ReluLayerDesc> ();
ret.reset(new ReluLayerDesc);
}
case LayerConf::kSoftmaxLayerConf: {
ret = std::make_shared<SoftmaxLayerDesc> ();
ret.reset(new SoftmaxLayerDesc);
}
case LayerConf::kMultinomialLogisticLossLayerConf: {
ret = std::make_shared<MultinomialLogisticLossLayerDesc> ();
ret.reset(new MultinomialLogisticLossLayerDesc);
}
default: {
LOG(FATAL) << "unknown layer";
......
......@@ -15,7 +15,7 @@ class LayerDescFactory {
return obj;
}
std::shared_ptr<BaseLayerDesc> ConstructLayerDesc(const LayerConf&) const;
std::unique_ptr<BaseLayerDesc> ConstructLayerDesc(const LayerConf&) const;
private:
LayerDescFactory() = default;
......
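ConstructLayerDesc now returns std::unique_ptr, so the caller takes sole ownership, which matches the assignment into mutable_layer_desc() in LogicalDag::BuildDagStruct. Note also that the switch in the .cpp above has no break statements, so every case falls through to LOG(FATAL); that looks like an oversight in this work-in-progress commit. A hedged sketch of the likely intent, returning from each case directly (only two layer types shown; the rest follow the same pattern):

std::unique_ptr<BaseLayerDesc> LayerDescFactory::ConstructLayerDesc(
    const LayerConf& layer_conf) const {
  switch (layer_conf.specified_type_case()) {
    case LayerConf::kConvolutionLayerConf:
      return std::unique_ptr<BaseLayerDesc>(new ConvolutionLayerDesc);
    case LayerConf::kInnerProductLayerConf:
      return std::unique_ptr<BaseLayerDesc>(new InnerProductLayerDesc);
    // ... remaining layer types elided ...
    default:
      LOG(FATAL) << "unknown layer";
      return nullptr;  // unreachable, keeps the compiler happy about returns
  }
}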