From c7a2b1d74ed5dbfed43f77a7210536929436d048 Mon Sep 17 00:00:00 2001 From: willzhang4a58 Date: Mon, 8 Jan 2018 18:01:00 +0800 Subject: [PATCH] nonrecurrent->normal --- oneflow/core/actor/backward_compute_actor.cpp | 2 +- oneflow/core/actor/forward_compute_actor.cpp | 2 +- oneflow/core/comm_network/rdma/connection.h | 6 ++--- oneflow/core/graph/chain_node.cpp | 8 +++--- .../nonrecurrent_forward_compute_task_node.h | 26 ------------------- ... => normal_backward_compute_task_node.cpp} | 11 ++++---- ....h => normal_backward_compute_task_node.h} | 18 ++++++------- ...p => normal_forward_compute_task_node.cpp} | 8 +++--- .../graph/normal_forward_compute_task_node.h | 24 +++++++++++++++++ oneflow/core/graph/task_node.cpp | 6 ++--- oneflow/core/job/compiler.cpp | 4 +-- oneflow/core/job/task.proto | 4 +-- 12 files changed, 56 insertions(+), 63 deletions(-) delete mode 100644 oneflow/core/graph/nonrecurrent_forward_compute_task_node.h rename oneflow/core/graph/{nonrecurrent_backward_compute_task_node.cpp => normal_backward_compute_task_node.cpp} (90%) rename oneflow/core/graph/{nonrecurrent_backward_compute_task_node.h => normal_backward_compute_task_node.h} (62%) rename oneflow/core/graph/{nonrecurrent_forward_compute_task_node.cpp => normal_forward_compute_task_node.cpp} (82%) create mode 100644 oneflow/core/graph/normal_forward_compute_task_node.h diff --git a/oneflow/core/actor/backward_compute_actor.cpp b/oneflow/core/actor/backward_compute_actor.cpp index ce3acba50e..922f39d7e7 100644 --- a/oneflow/core/actor/backward_compute_actor.cpp +++ b/oneflow/core/actor/backward_compute_actor.cpp @@ -124,6 +124,6 @@ void BackwardCompActor::Act() { } } -REGISTER_ACTOR(TaskType::kNonRecurrentBackward, BackwardCompActor); +REGISTER_ACTOR(TaskType::kNormalBackward, BackwardCompActor); } // namespace oneflow diff --git a/oneflow/core/actor/forward_compute_actor.cpp b/oneflow/core/actor/forward_compute_actor.cpp index d851031f24..6d284bc436 100644 --- a/oneflow/core/actor/forward_compute_actor.cpp +++ b/oneflow/core/actor/forward_compute_actor.cpp @@ -150,7 +150,7 @@ void ForwardCompActor::TryAsyncReturnModelTmpRegst() { } } -REGISTER_ACTOR(TaskType::kNonRecurrentForward, ForwardCompActor); +REGISTER_ACTOR(TaskType::kNormalForward, ForwardCompActor); REGISTER_ACTOR(TaskType::kLoss, ForwardCompActor); } // namespace oneflow diff --git a/oneflow/core/comm_network/rdma/connection.h b/oneflow/core/comm_network/rdma/connection.h index a03d51c45e..6eb4760e68 100644 --- a/oneflow/core/comm_network/rdma/connection.h +++ b/oneflow/core/comm_network/rdma/connection.h @@ -1,5 +1,5 @@ -#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H -#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H +#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_ +#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_ #ifdef WITH_RDMA @@ -49,4 +49,4 @@ class Connection { #endif // WITH_RDMA -#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H +#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_ diff --git a/oneflow/core/graph/chain_node.cpp b/oneflow/core/graph/chain_node.cpp index 6157240b6f..b2b204b563 100644 --- a/oneflow/core/graph/chain_node.cpp +++ b/oneflow/core/graph/chain_node.cpp @@ -1,8 +1,8 @@ #include "oneflow/core/graph/chain_node.h" #include "oneflow/core/graph/recurrent_backward_compute_task_node.h" -#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h" +#include "oneflow/core/graph/normal_backward_compute_task_node.h" #include "oneflow/core/graph/recurrent_forward_compute_task_node.h" -#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h" +#include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/loss_accumulate_compute_task_node.h" #include "oneflow/core/graph/loss_compute_task_node.h" #include "oneflow/core/graph/loss_print_compute_task_node.h" @@ -230,7 +230,7 @@ CompTaskNode* ForwardChainNode::NewCompTaskNode() const { if (HasSoleRecurrentOp()) { return new RecurrentForwardCompTaskNode; } else { - return new NonRecurrentForwardCompTaskNode; + return new NormalForwardCompTaskNode; } } @@ -290,7 +290,7 @@ CompTaskNode* BackwardChainNode::NewCompTaskNode() const { if (HasSoleRecurrentOp()) { return new RecurrentBackwardCompTaskNode; } else { - return new NonRecurrentBackwardCompTaskNode; + return new NormalBackwardCompTaskNode; } } diff --git a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.h b/oneflow/core/graph/nonrecurrent_forward_compute_task_node.h deleted file mode 100644 index 0e35ae1da5..0000000000 --- a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ - -#include "oneflow/core/graph/forward_compute_task_node.h" - -namespace oneflow { - -class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode { - public: - OF_DISALLOW_COPY_AND_MOVE(NonRecurrentForwardCompTaskNode); - NonRecurrentForwardCompTaskNode() = default; - ~NonRecurrentForwardCompTaskNode() = default; - - TaskType GetTaskType() const override { - return TaskType::kNonRecurrentForward; - } - bool IsReadyForBuild() override; - - private: - void VirtualConsumeInRegst(TaskEdge* edge) override; - void BuildExecGphStructAndBindInRegst() override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.cpp b/oneflow/core/graph/normal_backward_compute_task_node.cpp similarity index 90% rename from oneflow/core/graph/nonrecurrent_backward_compute_task_node.cpp rename to oneflow/core/graph/normal_backward_compute_task_node.cpp index 6dbde56aaf..386d37a0ec 100644 --- a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_backward_compute_task_node.cpp @@ -1,10 +1,9 @@ -#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h" +#include "oneflow/core/graph/normal_backward_compute_task_node.h" #include "oneflow/core/graph/chain_node.h" namespace oneflow { -void NonRecurrentBackwardCompTaskNode:: - VirtualBuildExecGphAndBindOutDiffRegst() { +void NormalBackwardCompTaskNode::VirtualBuildExecGphAndBindOutDiffRegst() { HashMap> lbn2producer; for (std::shared_ptr op : chain_node()->op_vec()) { ExecNode* cur_node = mut_exec_gph().NewNode(); @@ -34,7 +33,7 @@ void NonRecurrentBackwardCompTaskNode:: }); } -void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { +void NormalBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { std::shared_ptr activation_regst = GetConsumedRegst("activation"); auto activation_diff_regst = GetProducedRegst("activation_diff"); mut_exec_gph().ForEachEdge([&](ExecEdge* edge) { @@ -57,7 +56,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { }); } -void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() { +void NormalBackwardCompTaskNode::VirtualBuildInDiffRegst() { std::shared_ptr in_diff_regst = GetProducedRegst("in_diff"); std::shared_ptr in_regst = GetConsumedRegst("in"); mut_exec_gph().ForEachNode([&](ExecNode* cur_node) { @@ -77,7 +76,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() { }); } -void NonRecurrentBackwardCompTaskNode::VirtualConsumeInRegst() { +void NormalBackwardCompTaskNode::VirtualConsumeInRegst() { TaskNode* fw_node = GetRelatedFwTaskNode(); for (TaskEdge* edge : fw_node->in_edges()) { TaskNode* pred_fw_node = edge->src_node(); diff --git a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.h b/oneflow/core/graph/normal_backward_compute_task_node.h similarity index 62% rename from oneflow/core/graph/nonrecurrent_backward_compute_task_node.h rename to oneflow/core/graph/normal_backward_compute_task_node.h index 2cab3b7062..29d1687ef6 100644 --- a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.h +++ b/oneflow/core/graph/normal_backward_compute_task_node.h @@ -1,19 +1,17 @@ -#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ +#ifndef ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_ #include "oneflow/core/graph/backward_compute_task_node.h" namespace oneflow { -class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode { +class NormalBackwardCompTaskNode final : public BackwardCompTaskNode { public: - OF_DISALLOW_COPY_AND_MOVE(NonRecurrentBackwardCompTaskNode); - NonRecurrentBackwardCompTaskNode() = default; - ~NonRecurrentBackwardCompTaskNode() = default; + OF_DISALLOW_COPY_AND_MOVE(NormalBackwardCompTaskNode); + NormalBackwardCompTaskNode() = default; + ~NormalBackwardCompTaskNode() = default; - TaskType GetTaskType() const override { - return TaskType::kNonRecurrentBackward; - } + TaskType GetTaskType() const override { return TaskType::kNormalBackward; } private: void VirtualBuildExecGphAndBindOutDiffRegst() override; @@ -38,4 +36,4 @@ class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode { } // namespace oneflow -#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ +#endif // ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.cpp b/oneflow/core/graph/normal_forward_compute_task_node.cpp similarity index 82% rename from oneflow/core/graph/nonrecurrent_forward_compute_task_node.cpp rename to oneflow/core/graph/normal_forward_compute_task_node.cpp index 73ca947cd0..fed4fa39c1 100644 --- a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_forward_compute_task_node.cpp @@ -1,14 +1,14 @@ #include "oneflow/core/graph/forward_compute_task_node.h" -#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h" +#include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/chain_node.h" namespace oneflow { -void NonRecurrentForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) { +void NormalForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); } -void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { +void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { HashMap> lbn2producer; for (std::shared_ptr op : chain_node()->op_vec()) { ExecNode* cur_node = mut_exec_gph().NewNode(); @@ -36,7 +36,7 @@ void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { }); } -bool NonRecurrentForwardCompTaskNode::IsReadyForBuild() { +bool NormalForwardCompTaskNode::IsReadyForBuild() { return GetConsumedRegst("in")->IsLocked(); } diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h new file mode 100644 index 0000000000..30f770ffe3 --- /dev/null +++ b/oneflow/core/graph/normal_forward_compute_task_node.h @@ -0,0 +1,24 @@ +#ifndef ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ + +#include "oneflow/core/graph/forward_compute_task_node.h" + +namespace oneflow { + +class NormalForwardCompTaskNode final : public ForwardCompTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(NormalForwardCompTaskNode); + NormalForwardCompTaskNode() = default; + ~NormalForwardCompTaskNode() = default; + + TaskType GetTaskType() const override { return TaskType::kNormalForward; } + bool IsReadyForBuild() override; + + private: + void VirtualConsumeInRegst(TaskEdge* edge) override; + void BuildExecGphStructAndBindInRegst() override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 78afcf10df..61d3cada38 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -3,13 +3,11 @@ namespace oneflow { bool IsForwardTaskType(TaskType tt) { - return tt == TaskType::kNonRecurrentForward - || tt == TaskType::kRecurrentForward; + return tt == TaskType::kNormalForward || tt == TaskType::kRecurrentForward; } bool IsBackwardTaskType(TaskType tt) { - return tt == TaskType::kNonRecurrentBackward - || tt == TaskType::kRecurrentBackward; + return tt == TaskType::kNormalBackward || tt == TaskType::kRecurrentBackward; } TaskNode::TaskNode() : machine_id_(-1), thrd_id_(-1), task_id_(-1) {} diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index de2a462310..983bf5ab2c 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -6,9 +6,9 @@ namespace { std::map task_type2color = { {kInvalid, "0"}, - {kNonRecurrentForward, "2"}, + {kNormalForward, "2"}, {kRecurrentForward, "2"}, - {kNonRecurrentBackward, "3"}, + {kNormalBackward, "3"}, {kRecurrentBackward, "3"}, {kSource, "1"}, {kLoss, "4"}, diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index 068ff3978b..98b8af46f8 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -7,9 +7,9 @@ import "oneflow/core/job/placement.proto"; enum TaskType { kInvalid = 0; - kNonRecurrentForward = 1; + kNormalForward = 1; kRecurrentForward = 2; - kNonRecurrentBackward = 3; + kNormalBackward = 3; kRecurrentBackward = 4; kSource = 5; kLoss = 6; -- GitLab