Commit fc4900c6 authored by: W willzhang4a58

remove empty register

Parent 1c9894a7
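
In short, the diff below removes the RegstDescMgr singleton: RegstDesc objects are now created directly with of_make_unique, produced regsts whose lbn2shape map stays empty are pruned by the new TaskNode::RemoveRegstsWithoutBlob, and GetProducedRegstDesc returns nullptr for a pruned name, so callers guard the lookup. A minimal sketch of the resulting calling contract (RegstDesc, GetProducedRegstDesc and CopyShapeFrom come from the diff; task_node and in_regst are placeholder names):

    // Sketch only: after RemoveRegstsWithoutBlob() has run, a produced regst may
    // no longer exist, so lookups are null-checked before use.
    if (RegstDesc* diff_regst = task_node->GetProducedRegstDesc("model_diff")) {
      diff_regst->CopyShapeFrom(in_regst);  // only when the regst survived pruning
    }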
......@@ -38,11 +38,11 @@ void BoxingTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
void BoxingTaskNode::EnrollAllRegstAndBindRelatedEdge() {
for (TaskEdge* edge : out_edges()) {
std::string name = "boxing_out_" + edge->edge_id_str();
auto regst_desc = RegstDescMgr::Singleton().CreateRegisterDesc();
auto regst_desc = of_make_unique<RegstDesc> ();
BindProducedRegstAndOutEdge(regst_desc.get(), edge);
EnrollProducedRegstDesc(name, std::move(regst_desc));
}
auto regst_desc = RegstDescMgr::Singleton().CreateRegisterDesc();
auto regst_desc = of_make_unique<RegstDesc> ();
EnrollProducedRegstDesc("middle", std::move(regst_desc));
}
......
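Throughout this commit, RegstDescMgr::Singleton().CreateRegisterDesc() is replaced with a direct of_make_unique<RegstDesc>() call. of_make_unique itself is not defined in this diff; presumably it is a project-local stand-in for std::make_unique (which only arrived in C++14). A hypothetical definition, shown only as a sketch:

    #include <memory>
    #include <utility>

    // Assumed shape of the helper; the real definition in the repo may differ.
    template<typename T, typename... Args>
    std::unique_ptr<T> of_make_unique(Args&&... args) {
      return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
    }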
......@@ -235,6 +235,15 @@ std::string ChainNode::ConcatedOpsName() const {
}
}
bool ChainNode::HasOpWithModelOrModelTmpBlob() const {
for (std::shared_ptr<const Operator> op : op_vec_) {
if (!op->model_bns().empty() || !op->model_tmp_bns().empty()) {
return true;
}
}
return false;
}
ChainGraph::ChainGraph(const LogicalGraph* logical_gph,
const std::string& dot_filepath) {
LOG(INFO) << "Build ChainGraph...";
......
......@@ -55,6 +55,8 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
}
std::string VisualStr() const { return ConcatedOpsName(); }
bool HasOpWithModelOrModelTmpBlob() const;
private:
std::vector<std::shared_ptr<const Operator>> op_vec_;
......@@ -64,6 +66,7 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
};
class ChainEdge final : public Edge<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainEdge);
......
......@@ -24,16 +24,16 @@ void CompTaskNode::DataFwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
mut_exec_gph().UpdateSourceAndSink();
// out regst
if (!out_edges().empty()) {
auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto out_regst = of_make_unique<RegstDesc> ();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
EnrollProducedRegstDesc("out", std::move(out_regst));
}
// the other produced regsts
auto activation_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto data_tmp_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto model_tmp_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto model_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto log_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto activation_regst = of_make_unique<RegstDesc> ();
auto data_tmp_regst = of_make_unique<RegstDesc> ();
auto model_tmp_regst = of_make_unique<RegstDesc> ();
auto model_regst = of_make_unique<RegstDesc> ();
auto log_regst = of_make_unique<RegstDesc> ();
// EnrollProducedRegstDesc
EnrollProducedRegstDesc("activation", std::move(activation_regst));
EnrollProducedRegstDesc("data_tmp", std::move(data_tmp_regst));
......@@ -98,7 +98,7 @@ void CompTaskNode::MdLoadFwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
exec_node->mut_op() = chain_node()->SoleOp();
mut_exec_gph().UpdateSourceAndSink();
auto model_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto model_regst = of_make_unique<RegstDesc> ();
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(), model_regst.get());
BindProducedRegstAndOutEdge(model_regst.get(), SoleOutEdge());
CompTaskNode* update_0 = md_load_gph->parallel_id2updt_task().at(0);
......@@ -271,9 +271,9 @@ void CompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
HashMap<ExecEdge*, const ExecEdge*> bp_edge2fw_edge;
BpBuildExecGraph(fw_gph, &fw_node2bp_node, &bp_edge2fw_edge);
// Produced registers
auto in_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto model_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto activation_diff_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto in_diff_regst = of_make_unique<RegstDesc> ();
auto model_diff_regst = of_make_unique<RegstDesc> ();
auto activation_diff_regst = of_make_unique<RegstDesc> ();
// Bind out edge
if (!out_edges().empty()) {
BindProducedRegstAndOutEdge(in_diff_regst.get(), SoleOutEdge());
......@@ -292,12 +292,14 @@ void CompTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
RegstDesc* in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
in_diff_regst->CopyShapeFrom(in_regst);
// model_diff_regst
RegstDesc* model_diff_regst = GetProducedRegstDesc("model_diff");
model_diff_regst->CopyShapeFrom(GetFwNode()->exec_gph().RelatedModelRegst());
if (RegstDesc* md_diff_regst = GetProducedRegstDesc("model_diff")) {
md_diff_regst->CopyShapeFrom(GetFwNode()->exec_gph().RelatedModelRegst());
}
// activation_diff_regst
RegstDesc* activation_diff_regst = GetProducedRegstDesc("activation_diff");
RegstDesc* activation_regst = GetFwNode()->GetProducedRegstDesc("activation");
activation_diff_regst->CopyShapeFrom(activation_regst);
if (RegstDesc* acti_diff_regst = GetProducedRegstDesc("activation_diff")) {
RegstDesc* acti_regst = GetFwNode()->GetProducedRegstDesc("activation");
acti_diff_regst->CopyShapeFrom(acti_regst);
}
}
void CompTaskNode::BpBuildExecGraph(
......
......@@ -6,7 +6,7 @@
namespace oneflow {
void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*){
auto out_regst = RegstDescMgr::Singleton().CreateRegisterDesc();
auto out_regst = of_make_unique<RegstDesc> ();
BindProducedRegstAndOutEdge(out_regst.get(), SoleOutEdge());
RegstDesc* in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyLbnFrom(in_regst);
......
......@@ -14,7 +14,7 @@ void ExecNode::ToProto(ExecNodeProto* ret) const {
bn_regst.first, bn_regst.second->regst_desc_id()});
}
for (ExecEdge* edge: in_edges()) {
ret->add_predecessor_ids(edge->src_node()->node_id());
ret->add_predecessor_id(edge->src_node()->node_id());
}
}
......@@ -29,7 +29,7 @@ RegstDesc* ExecGraph::RelatedModelRegst() const {
void ExecGraph::ToProto(ExecGraphProto* ret) const {
for (const std::unique_ptr<ExecNode>& node: nodes()) {
node->ToProto(ret->add_exec_nodes());
node->ToProto(ret->add_exec_node());
}
}
......
......@@ -5,9 +5,9 @@ message ExecNodeProto {
uint64 id = 1;
string op_name = 2;
map<string, uint64> bn_in_op2regst_desc_id = 3;
repeated uint64 predecessor_ids = 4;
repeated uint64 predecessor_id = 4;
}
message ExecGraphProto {
repeated ExecNodeProto exec_nodes = 1;
repeated ExecNodeProto exec_node = 1;
}
......@@ -23,8 +23,8 @@ void LogicalGraph::NaiveBuildGraphStruct(
HashMap<LogicalEdge*, std::string>* edge2ibn) {
HashMap<std::string, LogicalNode*> lbn2producer;
// Process Op
for (int op_i = 0; op_i < dl_net_conf.op_conf_size(); ++op_i) {
const OperatorConf& cur_op_conf = dl_net_conf.op_conf(op_i);
for (int op_i = 0; op_i < dl_net_conf.op_size(); ++op_i) {
const OperatorConf& cur_op_conf = dl_net_conf.op(op_i);
// Construct cur node
LogicalNode* cur_node = NewNode();
cur_node->mut_op() = OpMgr::Singleton().ConstructOp(cur_op_conf);
......
#include "graph/task_graph_manager.h"
namespace oneflow {
void TaskGraphMgr::BuildGraphs() {
ordered_task_gphs_.clear();
// data graph
LOG(INFO) << "Build DataTaskGraph...";
auto data_task_gph = new DataTaskGraph(
"data",
JobDesc::Singleton().train_dlnet_conf(),
JobDesc::Singleton().strategy(),
true);
ordered_task_gphs_.emplace_back(data_task_gph);
// construct data_chain2sorted_bp_comp_tasks
HashMap<const ChainNode*, std::vector<CompTaskNode*>>
data_chain2sorted_bp_comp_tasks;
for (const auto& node : data_task_gph->nodes()) {
auto bp_node = dynamic_cast<CompTaskNode*>(node.get());
if (bp_node == nullptr || bp_node->IsFwNode()) { continue; }
data_chain2sorted_bp_comp_tasks[bp_node->chain_node()].push_back(bp_node);
}
for (auto& pair : data_chain2sorted_bp_comp_tasks) {
SortByParallelId(&(pair.second));
}
// model graph
for (const auto& pair : data_chain2sorted_bp_comp_tasks) {
std::string chain_tag = pair.first->op_vec().front()->op_name();
str_replace(&chain_tag, '/', '_');
const std::string dot_path_prefix = DotDir() + "/model/" + chain_tag + "_";
ParallelPolicy policy = pair.first->parallel_desc()->policy();
// model update
LOG(INFO) << "Build MdUpdtTaskGraph... for " << chain_tag;
auto updt_gph = new MdUpdtTaskGraph(
"md_updt_" + chain_tag,
pair.first, pair.second, dot_path_prefix + "model_update_");
ChainNode* updt_chain = updt_gph->chain_gph()->SoleSinkNode();
auto sorted_updt_tasks = updt_gph->SortedCompTasksInChain(updt_chain);
HashMap<uint64_t, CompTaskNode*> parallel_id2updt_task;
for (CompTaskNode* update_task : sorted_updt_tasks) {
CHECK(parallel_id2updt_task.emplace(
update_task->parallel_id(), update_task).second);
}
// model load save
LOG(INFO) << "Build MdLoadTaskGraph... for " << chain_tag;
auto load_gph = new MdLoadTaskGraph(
"md_load_" + chain_tag,
updt_chain, parallel_id2updt_task, policy,
dot_path_prefix + "model_load_");
LOG(INFO) << "Build MdSaveTaskGraph... for " << chain_tag;
auto save_gph = new MdSaveTaskGraph(
"md_save_" + chain_tag,
updt_chain, parallel_id2updt_task, policy,
dot_path_prefix + "model_save_");
ordered_task_gphs_.emplace_back(updt_gph);
ordered_task_gphs_.emplace_back(load_gph);
ordered_task_gphs_.emplace_back(save_gph);
}
// all exec_graph 2 dot
for (const auto& task_gph : ordered_task_gphs_) {
for (const auto& task_node : task_gph->nodes()) {
std::string file_path = DotDir() + "/exec/";
file_path += task_node->node_id_str() + ".dot";
task_node->exec_gph().ToDotFile(file_path);
}
}
}
void TaskGraphMgr::InferShape4Regsts() {
for (auto& task_gph : ordered_task_gphs_) {
LOG(INFO) << "InferShape... for " << task_gph->name();
task_gph->InferShapeOfBlobsInProducedRegsts();
}
}
void TaskGraphMgr::AllTaskNodesToProto(PbRpf<TaskProto>* ret) {
ret->Clear();
for (const auto& task_gph : ordered_task_gphs_) {
for (const auto& task_node : task_gph->nodes()) {
task_node->ToProto(ret->Add());
}
}
}
} // namespace oneflow
#ifndef ONEFLOW_GRAPH_TASK_GRAPH_MANAGER_H_
#define ONEFLOW_GRAPH_TASK_GRAPH_MANAGER_H_
#include "job/job_desc.h"
#include "graph/data_task_graph.h"
#include "graph/model_load_task_graph.h"
#include "graph/model_save_task_graph.h"
#include "graph/model_update_task_graph.h"
namespace oneflow {
class TaskGraphMgr {
public:
OF_DISALLOW_COPY_AND_MOVE(TaskGraphMgr);
~TaskGraphMgr() = default;
static TaskGraphMgr& Singleton() {
static TaskGraphMgr obj;
return obj;
}
void BuildGraphs();
void InferShape4Regsts();
void AllTaskNodesToProto(PbRpf<TaskProto>*);
private:
TaskGraphMgr() = default;
std::vector<std::unique_ptr<TaskGraph>> ordered_task_gphs_;
};
} // namespace oneflow
#endif // ONEFLOW_GRAPH_TASK_GRAPH_MANAGER_H_
......@@ -41,7 +41,12 @@ std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
}
RegstDesc* TaskNode::GetProducedRegstDesc(const std::string& regst_desc_name) {
return produced_regst_descs_.at(regst_desc_name).get();
auto it = produced_regst_descs_.find(regst_desc_name);
if (it == produced_regst_descs_.end()) {
return nullptr;
} else {
return it->second.get();
}
}
void TaskNode::TakeOverRegstDesc(TaskNode* rhs,
......@@ -57,6 +62,18 @@ void TaskNode::TakeOverRegstDesc(TaskNode* rhs,
std::move(this_regst)).second);
}
void TaskNode::RemoveRegstsWithoutBlob() {
for (auto it = produced_regst_descs_.begin();
it != produced_regst_descs_.end();) {
if (it->second->lbn2shape().empty()) {
auto cur_it = it++;
produced_regst_descs_.erase(cur_it);
} else {
++it;
}
}
}
const TaskEdge* TaskNode::GetOutEdge4ProducedRegst(RegstDesc* regst) const {
return produced_regst2out_edge_.at(regst);
}
......@@ -94,20 +111,7 @@ void TaskNode::ToProto(TaskProto* ret) const {
ret->set_is_forward(is_fw_node_);
exec_gph_.ToProto(ret->mutable_exec_graph());
for (const auto& pair : produced_regst_descs_) {
ret->mutable_produced_regst_desc_ids()->Add(
pair.second->regst_desc_id());
}
// subscribed_regsts
std::unordered_set<RegstDesc*> subscribed_regsts;
for (const auto& exec_node : exec_gph().nodes()) {
for (const auto& pair : exec_node->bn_in_op2regst()) {
RegstDesc* related_regst = pair.second;
if (related_regst->GetProducer() == this) { continue; }
subscribed_regsts.insert(related_regst);
}
}
for (RegstDesc* regst : subscribed_regsts) {
ret->mutable_subscribed_regst_desc_ids()->Add(regst->regst_desc_id());
pair.second->ToProto(ret->mutable_produced_regst_desc()->Add());
}
}
......
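The loop added in RemoveRegstsWithoutBlob above uses the post-increment-before-erase idiom, which is safe for node-based associative containers. Under C++11, where erase(iterator) on an unordered_map-like container returns the following iterator, the same pruning can be written more compactly; a functionally equivalent sketch, assuming HashMap behaves like std::unordered_map:

    void TaskNode::RemoveRegstsWithoutBlob() {
      for (auto it = produced_regst_descs_.begin();
           it != produced_regst_descs_.end();) {
        if (it->second->lbn2shape().empty()) {
          it = produced_regst_descs_.erase(it);  // erase and step forward in one call
        } else {
          ++it;
        }
      }
    }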
......@@ -4,7 +4,6 @@
#include "task/task.pb.h"
#include "graph/stage_graph.h"
#include "graph/exec_graph.h"
#include "register/register_desc_manager.h"
namespace oneflow {
......@@ -55,6 +54,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
//
RegstDesc* GetProducedRegstDesc(const std::string& regst_desc_name);
void TakeOverRegstDesc(TaskNode* rhs, const std::string& regst_desc_name);
void RemoveRegstsWithoutBlob();
//
const TaskEdge* GetOutEdge4ProducedRegst(RegstDesc*) const;
......
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "job/id_manager.h"
#include "graph/task_graph_manager.h"
#include "common/proto_io.h"
#include "graph/model_load_task_graph.h"
#include "graph/model_save_task_graph.h"
#include "graph/model_update_task_graph.h"
#include "graph/data_task_graph.h"
#include "job/job_conf.pb.h"
#include "job/ofelf.pb.h"
DEFINE_string(job_conf_filepath, "", "");
DEFINE_string(elf_filepath, "", "");
namespace oneflow {
class Compiler final {
......@@ -20,35 +21,124 @@ class Compiler final {
return obj;
}
void Compile(const JobConf& job_conf) {
JobDesc::Singleton().InitFromJobConf(job_conf);
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
TaskGraphMgr::Singleton().BuildGraphs();
JobDesc::Singleton().set_piece_size(50); // TODO: set appropriate piece_size
TaskGraphMgr::Singleton().InferShape4Regsts();
// To Proto
OfElf elf;
TaskGraphMgr::Singleton().AllTaskNodesToProto(elf.mutable_tasks());
RegstDescMgr::Singleton().AllRegstsToProto(elf.mutable_regst_descs());
OpMgr::Singleton().AllOpToProto(elf.mutable_operators());
JobDesc::Singleton().ToProto(elf.mutable_job_desc());
PrintProtoToTextFile(elf, FLAGS_elf_filepath);
}
void Compile(const JobConf& job_conf, const std::string& elf_filepath);
private:
Compiler() = default;
void RunFunc4EachTaskNode(std::function<void(TaskNode*)> func);
void BuildGraphs();
void RemoveRegstsWithoutBlob();
void InferShape4Regsts();
std::vector<std::unique_ptr<TaskGraph>> ordered_task_gphs_;
};
void Compiler::RunFunc4EachTaskNode(std::function<void(TaskNode*)> func) {
for (const auto& task_gph : ordered_task_gphs_) {
for (const auto& task_node : task_gph->nodes()) {
func(task_node.get());
}
}
}
// TODO: inference "piece_size" and "register_num for each register_desc"
void Compiler::Compile(const JobConf& job_conf,
const std::string& elf_filepath) {
JobDesc::Singleton().InitFromJobConf(job_conf);
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
BuildGraphs();
RunFunc4EachTaskNode([](TaskNode* node) { node->RemoveRegstsWithoutBlob(); });
InferShape4Regsts();
OfElf elf;
RunFunc4EachTaskNode([&elf](TaskNode* node) {
node->ToProto(elf.mutable_task()->Add());
});
OpMgr::Singleton().AllOpToProto(elf.mutable_op());
JobDesc::Singleton().ToProto(elf.mutable_job_desc());
PrintProtoToTextFile(elf, elf_filepath);
}
void Compiler::BuildGraphs() {
ordered_task_gphs_.clear();
// data graph
LOG(INFO) << "Build DataTaskGraph...";
auto data_task_gph = new DataTaskGraph(
"data",
JobDesc::Singleton().train_dlnet_conf(),
JobDesc::Singleton().strategy(),
true);
ordered_task_gphs_.emplace_back(data_task_gph);
// construct data_chain2sorted_bp_comp_tasks
HashMap<const ChainNode*, std::vector<CompTaskNode*>>
data_chain2sorted_bp_comp_tasks;
for (const auto& node : data_task_gph->nodes()) {
auto bp_node = dynamic_cast<CompTaskNode*>(node.get());
if (bp_node == nullptr || bp_node->IsFwNode()) { continue; }
data_chain2sorted_bp_comp_tasks[bp_node->chain_node()].push_back(bp_node);
}
for (auto& pair : data_chain2sorted_bp_comp_tasks) {
SortByParallelId(&(pair.second));
}
// model graph
for (const auto& pair : data_chain2sorted_bp_comp_tasks) {
if (pair.first->HasOpWithModelOrModelTmpBlob() == false) { continue; }
std::string chain_tag = pair.first->op_vec().front()->op_name();
str_replace(&chain_tag, '/', '_');
const std::string dot_path_prefix = DotDir() + "/model/" + chain_tag + "_";
ParallelPolicy policy = pair.first->parallel_desc()->policy();
LOG(INFO) << "Build MdUpdtTaskGraph... for " << chain_tag;
auto updt_gph = new MdUpdtTaskGraph(
"md_updt_" + chain_tag,
pair.first, pair.second, dot_path_prefix + "model_update_");
ChainNode* updt_chain = updt_gph->chain_gph()->SoleSinkNode();
auto sorted_updt_tasks = updt_gph->SortedCompTasksInChain(updt_chain);
HashMap<uint64_t, CompTaskNode*> parallel_id2updt_task;
for (CompTaskNode* update_task : sorted_updt_tasks) {
CHECK(parallel_id2updt_task.emplace(
update_task->parallel_id(), update_task).second);
}
LOG(INFO) << "Build MdLoadTaskGraph... for " << chain_tag;
auto load_gph = new MdLoadTaskGraph(
"md_load_" + chain_tag,
updt_chain, parallel_id2updt_task, policy,
dot_path_prefix + "model_load_");
LOG(INFO) << "Build MdSaveTaskGraph... for " << chain_tag;
auto save_gph = new MdSaveTaskGraph(
"md_save_" + chain_tag,
updt_chain, parallel_id2updt_task, policy,
dot_path_prefix + "model_save_");
ordered_task_gphs_.emplace_back(updt_gph);
ordered_task_gphs_.emplace_back(load_gph);
ordered_task_gphs_.emplace_back(save_gph);
}
// all exec_graph 2 dot
RunFunc4EachTaskNode([](TaskNode* node) {
std::string file_path = DotDir() + "/exec/" + node->node_id_str() + ".dot";
node->exec_gph().ToDotFile(file_path);
});
}
void Compiler::InferShape4Regsts() {
for (auto& task_gph : ordered_task_gphs_) {
LOG(INFO) << "InferShape... for " << task_gph->name();
task_gph->InferShapeOfBlobsInProducedRegsts();
}
}
} // namespace oneflow
DEFINE_string(job_conf_filepath, "", "");
DEFINE_string(elf_filepath, "", "");
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
google::ParseCommandLineFlags(&argc, &argv, true);
LOG(INFO) << "Compiler Starting Up...";
oneflow::JobConf job_conf;
oneflow::ParseProtoFromTextFile(FLAGS_job_conf_filepath, &job_conf);
oneflow::Compiler::Singleton().Compile(job_conf);
oneflow::Compiler::Singleton().Compile(job_conf, FLAGS_elf_filepath);
LOG(INFO) << "Compiler Shutting Down...";
return 0;
}
......@@ -5,5 +5,5 @@ import "operator/op_conf.proto";
message DLNetConf {
string name = 1;
repeated OperatorConf op_conf = 100;
repeated OperatorConf op = 100;
}
......@@ -8,4 +8,5 @@ message JobConf {
string model_load_machine = 4;
string model_save_machine = 5;
uint32 batch_size = 6;
uint32 piece_size = 7;
}
......@@ -11,7 +11,7 @@ void JobDesc::InitFromJobConf(const JobConf& conf) {
md_load_machine_ = conf.model_load_machine();
md_save_machine_ = conf.model_save_machine();
batch_size_ = conf.batch_size();
piece_size_ = 0; // TODO
piece_size_ = conf.piece_size();
}
void JobDesc::InitFromProto(const JobDescProto& proto) {
......
......@@ -7,8 +7,7 @@ import "task/task.proto";
import "job/job_desc.proto";
message OfElf {
repeated TaskProto tasks = 1;
repeated RegstDescProto regst_descs = 2;
repeated OperatorProto operators = 3;
repeated TaskProto task = 1;
repeated OperatorProto op = 3;
JobDescProto job_desc = 4;
}
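Renaming the repeated proto3 fields from plural to singular (tasks → task, operators → op, and the *_bns → *_bn fields further down) also renames the generated C++ accessors, which is why the Compiler and Operator code change in lock-step. A small illustrative use of the new OfElf field; the wrapper function is hypothetical, only the message, field, and accessor names come from this diff:

    #include "job/ofelf.pb.h"  // generated from the ofelf.proto above

    void AppendTask(oneflow::OfElf* elf) {
      // "repeated TaskProto task = 1;" generates task(), mutable_task() and add_task()
      oneflow::TaskProto* task_proto = elf->mutable_task()->Add();
      task_proto->set_is_forward(true);  // is_forward is field 5 of TaskProto
    }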
......@@ -5,27 +5,27 @@ namespace oneflow {
void Operator::InitFromProto(const OperatorProto& op_proto) {
op_conf_ = op_proto.op_conf();
bn_in_op2lbn_ = PbMap2HashMap(op_proto.bn_in_op2lbn());
data_tmp_bns_ = PbVec2StdVec(op_proto.data_tmp_bns());
input_bns_ = PbVec2StdVec(op_proto.input_bns());
input_diff_bns_ = PbVec2StdVec(op_proto.input_diff_bns());
output_bns_ = PbVec2StdVec(op_proto.output_bns());
output_diff_bns_ = PbVec2StdVec(op_proto.output_diff_bns());
model_bns_ = PbVec2StdVec(op_proto.model_bns());
model_diff_bns_ = PbVec2StdVec(op_proto.model_diff_bns());
model_tmp_bns_ = PbVec2StdVec(op_proto.model_tmp_bns());
data_tmp_bns_ = PbVec2StdVec(op_proto.data_tmp_bn());
input_bns_ = PbVec2StdVec(op_proto.input_bn());
input_diff_bns_ = PbVec2StdVec(op_proto.input_diff_bn());
output_bns_ = PbVec2StdVec(op_proto.output_bn());
output_diff_bns_ = PbVec2StdVec(op_proto.output_diff_bn());
model_bns_ = PbVec2StdVec(op_proto.model_bn());
model_diff_bns_ = PbVec2StdVec(op_proto.model_diff_bn());
model_tmp_bns_ = PbVec2StdVec(op_proto.model_tmp_bn());
}
void Operator::ToProto(OperatorProto* ret) const {
*(ret->mutable_op_conf()) = op_conf_;
*(ret->mutable_bn_in_op2lbn()) = HashMap2PbMap(bn_in_op2lbn_);
*(ret->mutable_data_tmp_bns()) = StdVec2PbVec(data_tmp_bns_);
*(ret->mutable_input_bns()) = StdVec2PbVec(input_bns_);
*(ret->mutable_input_diff_bns()) = StdVec2PbVec(input_diff_bns_);
*(ret->mutable_output_bns()) = StdVec2PbVec(output_bns_);
*(ret->mutable_output_diff_bns()) = StdVec2PbVec(output_diff_bns_);
*(ret->mutable_model_bns()) = StdVec2PbVec(model_bns_);
*(ret->mutable_model_diff_bns()) = StdVec2PbVec(model_diff_bns_);
*(ret->mutable_model_tmp_bns()) = StdVec2PbVec(model_tmp_bns_);
*(ret->mutable_data_tmp_bn()) = StdVec2PbVec(data_tmp_bns_);
*(ret->mutable_input_bn()) = StdVec2PbVec(input_bns_);
*(ret->mutable_input_diff_bn()) = StdVec2PbVec(input_diff_bns_);
*(ret->mutable_output_bn()) = StdVec2PbVec(output_bns_);
*(ret->mutable_output_diff_bn()) = StdVec2PbVec(output_diff_bns_);
*(ret->mutable_model_bn()) = StdVec2PbVec(model_bns_);
*(ret->mutable_model_diff_bn()) = StdVec2PbVec(model_diff_bns_);
*(ret->mutable_model_tmp_bn()) = StdVec2PbVec(model_tmp_bns_);
}
const std::string& Operator::Lbn4BnInOp(const std::string& bn_in_op) const {
......
......@@ -7,13 +7,13 @@ message OperatorProto {
OperatorConf op_conf = 1;
map<string, string> bn_in_op2lbn = 3;
repeated string data_tmp_bns = 4;
repeated string input_bns = 5;
repeated string input_diff_bns = 6;
repeated string output_bns = 7;
repeated string output_diff_bns = 8;
repeated string data_tmp_bn = 4;
repeated string input_bn = 5;
repeated string input_diff_bn = 6;
repeated string output_bn = 7;
repeated string output_diff_bn = 8;
repeated string model_bns = 9;
repeated string model_diff_bns = 10;
repeated string model_tmp_bns = 11;
repeated string model_bn = 9;
repeated string model_diff_bn = 10;
repeated string model_tmp_bn = 11;
}
......@@ -7,7 +7,7 @@ namespace oneflow {
RegstDesc::RegstDesc() {
producer_ = nullptr;
register_num_ = 0; // TODO
register_num_ = 5; // TODO
}
void RegstDesc::CopyLbnFrom(const RegstDesc* rhs) {
......
#ifndef ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#define ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
#include "register/register_desc.h"
namespace oneflow {
class RegstDescMgr final {
public:
OF_DISALLOW_COPY_AND_MOVE(RegstDescMgr);
~RegstDescMgr() = default;
static RegstDescMgr& Singleton() {
static RegstDescMgr obj;
return obj;
}
std::unique_ptr<RegstDesc> CreateRegisterDesc() {
auto ret = of_make_unique<RegstDesc> ();
regst_descs_.push_back(ret.get());
return ret;
}
void AllRegstsToProto(PbRpf<RegstDescProto>* ret) {
ret->Clear();
for (RegstDesc* regst : regst_descs_) {
regst->ToProto(ret->Add());
}
}
private:
RegstDescMgr() = default;
std::list<RegstDesc*> regst_descs_;
};
} // namespace oneflow
#endif // ONEFLOW_REGISTER_REGISTER_DESC_MANAGER_H_
......@@ -2,6 +2,7 @@ syntax = "proto3";
package oneflow;
import "graph/exec_graph.proto";
import "register/register_desc.proto";
enum TaskType {
HostCompTask = 0;
......@@ -19,8 +20,7 @@ message TaskProto {
uint64 thrd_local_id = 4;
bool is_forward = 5;
ExecGraphProto exec_graph = 6;
repeated uint64 produced_regst_desc_ids = 7;
repeated uint64 subscribed_regst_desc_ids = 8;
repeated RegstDescProto produced_regst_desc = 7;
// for CompTask
uint64 parallel_id = 1000;
}
name: "GoogleNet"
op_conf {
op {
name: "mnist"
data_loader_conf {
feature: "feature"
......@@ -13,7 +13,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv1"
convolution_conf {
in: "mnist/feature"
......@@ -28,7 +28,7 @@ op_conf {
}
}
op_conf {
op {
name: "pool1"
pooling_conf {
in: "conv1/out"
......@@ -43,7 +43,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv2_1x1"
convolution_conf {
in: "pool1/out"
......@@ -58,7 +58,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv2_1x3"
convolution_conf {
in: "pool1/out"
......@@ -73,7 +73,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv2_1x5"
convolution_conf {
in: "pool1/out"
......@@ -88,7 +88,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv2_3x3"
convolution_conf {
in: "conv2_1x3/out"
......@@ -103,7 +103,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv2_5x5"
convolution_conf {
in: "conv2_1x5/out"
......@@ -118,7 +118,7 @@ op_conf {
}
}
op_conf {
op {
name: "concat2"
concat_conf {
in: "conv2_1x1/out"
......@@ -129,7 +129,7 @@ op_conf {
}
}
op_conf {
op {
name: "ip1"
innerproduct_conf {
in: "concat2/out"
......@@ -138,7 +138,7 @@ op_conf {
}
}
op_conf {
op {
name: "relu1"
relu_conf {
in: "ip1/out"
......@@ -146,7 +146,7 @@ op_conf {
}
}
op_conf {
op {
name: "ip2"
innerproduct_conf {
in: "relu1/out"
......@@ -155,7 +155,7 @@ op_conf {
}
}
op_conf {
op {
name: "softmax1"
softmax_conf {
in: "ip2/out"
......@@ -163,7 +163,7 @@ op_conf {
}
}
op_conf {
op {
name: "loss1"
multinomial_logistic_loss_conf {
prediction: "softmax1/out"
......@@ -172,7 +172,7 @@ op_conf {
}
}
op_conf {
op {
name: "pool2"
pooling_conf {
in: "concat2/out"
......@@ -187,7 +187,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv3_1x1"
convolution_conf {
in: "pool2/out"
......@@ -202,7 +202,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv3_1x3"
convolution_conf {
in: "pool2/out"
......@@ -217,7 +217,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv3_1x5"
convolution_conf {
in: "pool2/out"
......@@ -232,7 +232,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv3_3x3"
convolution_conf {
in: "conv3_1x3/out"
......@@ -247,7 +247,7 @@ op_conf {
}
}
op_conf {
op {
name: "conv3_5x5"
convolution_conf {
in: "conv3_1x5/out"
......@@ -262,7 +262,7 @@ op_conf {
}
}
op_conf {
op {
name: "concat3"
concat_conf {
in: "conv3_1x1/out"
......@@ -273,7 +273,7 @@ op_conf {
}
}
op_conf {
op {
name: "ip3"
innerproduct_conf {
in: "concat3/out"
......@@ -282,7 +282,7 @@ op_conf {
}
}
op_conf {
op {
name: "relu3"
relu_conf {
in: "ip3/out"
......@@ -290,7 +290,7 @@ op_conf {
}
}
op_conf {
op {
name: "ip4"
innerproduct_conf {
in: "relu3/out"
......@@ -299,7 +299,7 @@ op_conf {
}
}
op_conf {
op {
name: "softmax2"
softmax_conf {
in: "ip4/out"
......@@ -307,7 +307,7 @@ op_conf {
}
}
op_conf {
op {
name: "loss2"
multinomial_logistic_loss_conf {
prediction: "softmax2/out"
......