diff --git a/oneflow/core/actor/backward_compute_actor.cpp b/oneflow/core/actor/backward_compute_actor.cpp index ce3acba50e6e9f1f0b25b6e765decce822a3e32e..922f39d7e748adbe01b0a3c9fefe61447ff232f5 100644 --- a/oneflow/core/actor/backward_compute_actor.cpp +++ b/oneflow/core/actor/backward_compute_actor.cpp @@ -124,6 +124,6 @@ void BackwardCompActor::Act() { } } -REGISTER_ACTOR(TaskType::kNonRecurrentBackward, BackwardCompActor); +REGISTER_ACTOR(TaskType::kNormalBackward, BackwardCompActor); } // namespace oneflow diff --git a/oneflow/core/actor/forward_compute_actor.cpp b/oneflow/core/actor/forward_compute_actor.cpp index d851031f247b9e8924468dc4caedc2b2753524e2..6d284bc436161b5c88ac6f6391c5c0c70247ab52 100644 --- a/oneflow/core/actor/forward_compute_actor.cpp +++ b/oneflow/core/actor/forward_compute_actor.cpp @@ -150,7 +150,7 @@ void ForwardCompActor::TryAsyncReturnModelTmpRegst() { } } -REGISTER_ACTOR(TaskType::kNonRecurrentForward, ForwardCompActor); +REGISTER_ACTOR(TaskType::kNormalForward, ForwardCompActor); REGISTER_ACTOR(TaskType::kLoss, ForwardCompActor); } // namespace oneflow diff --git a/oneflow/core/comm_network/rdma/connection.h b/oneflow/core/comm_network/rdma/connection.h index a03d51c45e0d195ec88495f65d1de151020d13cd..6eb4760e68c7fadb13c16e491d3db8a49db56566 100644 --- a/oneflow/core/comm_network/rdma/connection.h +++ b/oneflow/core/comm_network/rdma/connection.h @@ -1,5 +1,5 @@ -#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H -#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H +#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_ +#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_ #ifdef WITH_RDMA @@ -49,4 +49,4 @@ class Connection { #endif // WITH_RDMA -#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H +#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_ diff --git a/oneflow/core/graph/chain_node.cpp b/oneflow/core/graph/chain_node.cpp index 6157240b6ff08a2c80d339c9c1719e90d8c93783..b2b204b563b853e97099d57c6000163b67be305c 100644 --- a/oneflow/core/graph/chain_node.cpp +++ b/oneflow/core/graph/chain_node.cpp @@ -1,8 +1,8 @@ #include "oneflow/core/graph/chain_node.h" #include "oneflow/core/graph/recurrent_backward_compute_task_node.h" -#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h" +#include "oneflow/core/graph/normal_backward_compute_task_node.h" #include "oneflow/core/graph/recurrent_forward_compute_task_node.h" -#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h" +#include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/loss_accumulate_compute_task_node.h" #include "oneflow/core/graph/loss_compute_task_node.h" #include "oneflow/core/graph/loss_print_compute_task_node.h" @@ -230,7 +230,7 @@ CompTaskNode* ForwardChainNode::NewCompTaskNode() const { if (HasSoleRecurrentOp()) { return new RecurrentForwardCompTaskNode; } else { - return new NonRecurrentForwardCompTaskNode; + return new NormalForwardCompTaskNode; } } @@ -290,7 +290,7 @@ CompTaskNode* BackwardChainNode::NewCompTaskNode() const { if (HasSoleRecurrentOp()) { return new RecurrentBackwardCompTaskNode; } else { - return new NonRecurrentBackwardCompTaskNode; + return new NormalBackwardCompTaskNode; } } diff --git a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.h b/oneflow/core/graph/nonrecurrent_forward_compute_task_node.h deleted file mode 100644 index 0e35ae1da5320e38a8d04b9be7765c515ba2029c..0000000000000000000000000000000000000000 --- a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ - -#include "oneflow/core/graph/forward_compute_task_node.h" - -namespace oneflow { - -class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode { - public: - OF_DISALLOW_COPY_AND_MOVE(NonRecurrentForwardCompTaskNode); - NonRecurrentForwardCompTaskNode() = default; - ~NonRecurrentForwardCompTaskNode() = default; - - TaskType GetTaskType() const override { - return TaskType::kNonRecurrentForward; - } - bool IsReadyForBuild() override; - - private: - void VirtualConsumeInRegst(TaskEdge* edge) override; - void BuildExecGphStructAndBindInRegst() override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.cpp b/oneflow/core/graph/normal_backward_compute_task_node.cpp similarity index 90% rename from oneflow/core/graph/nonrecurrent_backward_compute_task_node.cpp rename to oneflow/core/graph/normal_backward_compute_task_node.cpp index 6dbde56aafe0fb617a6c2d68c94f96c53964b7d7..386d37a0eca4289f668aa323a8af772cd9d14c4b 100644 --- a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_backward_compute_task_node.cpp @@ -1,10 +1,9 @@ -#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h" +#include "oneflow/core/graph/normal_backward_compute_task_node.h" #include "oneflow/core/graph/chain_node.h" namespace oneflow { -void NonRecurrentBackwardCompTaskNode:: - VirtualBuildExecGphAndBindOutDiffRegst() { +void NormalBackwardCompTaskNode::VirtualBuildExecGphAndBindOutDiffRegst() { HashMap> lbn2producer; for (std::shared_ptr op : chain_node()->op_vec()) { ExecNode* cur_node = mut_exec_gph().NewNode(); @@ -34,7 +33,7 @@ void NonRecurrentBackwardCompTaskNode:: }); } -void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { +void NormalBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { std::shared_ptr activation_regst = GetConsumedRegst("activation"); auto activation_diff_regst = GetProducedRegst("activation_diff"); mut_exec_gph().ForEachEdge([&](ExecEdge* edge) { @@ -57,7 +56,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { }); } -void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() { +void NormalBackwardCompTaskNode::VirtualBuildInDiffRegst() { std::shared_ptr in_diff_regst = GetProducedRegst("in_diff"); std::shared_ptr in_regst = GetConsumedRegst("in"); mut_exec_gph().ForEachNode([&](ExecNode* cur_node) { @@ -77,7 +76,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() { }); } -void NonRecurrentBackwardCompTaskNode::VirtualConsumeInRegst() { +void NormalBackwardCompTaskNode::VirtualConsumeInRegst() { TaskNode* fw_node = GetRelatedFwTaskNode(); for (TaskEdge* edge : fw_node->in_edges()) { TaskNode* pred_fw_node = edge->src_node(); diff --git a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.h b/oneflow/core/graph/normal_backward_compute_task_node.h similarity index 62% rename from oneflow/core/graph/nonrecurrent_backward_compute_task_node.h rename to oneflow/core/graph/normal_backward_compute_task_node.h index 2cab3b7062de0e8bf5710bced3208e01cd1ae779..29d1687ef64de67105bbe7484bdee6a925a1f8c2 100644 --- a/oneflow/core/graph/nonrecurrent_backward_compute_task_node.h +++ b/oneflow/core/graph/normal_backward_compute_task_node.h @@ -1,19 +1,17 @@ -#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ +#ifndef ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_ #include "oneflow/core/graph/backward_compute_task_node.h" namespace oneflow { -class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode { +class NormalBackwardCompTaskNode final : public BackwardCompTaskNode { public: - OF_DISALLOW_COPY_AND_MOVE(NonRecurrentBackwardCompTaskNode); - NonRecurrentBackwardCompTaskNode() = default; - ~NonRecurrentBackwardCompTaskNode() = default; + OF_DISALLOW_COPY_AND_MOVE(NormalBackwardCompTaskNode); + NormalBackwardCompTaskNode() = default; + ~NormalBackwardCompTaskNode() = default; - TaskType GetTaskType() const override { - return TaskType::kNonRecurrentBackward; - } + TaskType GetTaskType() const override { return TaskType::kNormalBackward; } private: void VirtualBuildExecGphAndBindOutDiffRegst() override; @@ -38,4 +36,4 @@ class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode { } // namespace oneflow -#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ +#endif // ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.cpp b/oneflow/core/graph/normal_forward_compute_task_node.cpp similarity index 82% rename from oneflow/core/graph/nonrecurrent_forward_compute_task_node.cpp rename to oneflow/core/graph/normal_forward_compute_task_node.cpp index 73ca947cd013826723f890f1b38730e117ebaafb..fed4fa39c162daadf31bb81bb9062bd7a3033a52 100644 --- a/oneflow/core/graph/nonrecurrent_forward_compute_task_node.cpp +++ b/oneflow/core/graph/normal_forward_compute_task_node.cpp @@ -1,14 +1,14 @@ #include "oneflow/core/graph/forward_compute_task_node.h" -#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h" +#include "oneflow/core/graph/normal_forward_compute_task_node.h" #include "oneflow/core/graph/chain_node.h" namespace oneflow { -void NonRecurrentForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) { +void NormalForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); } -void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { +void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { HashMap> lbn2producer; for (std::shared_ptr op : chain_node()->op_vec()) { ExecNode* cur_node = mut_exec_gph().NewNode(); @@ -36,7 +36,7 @@ void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { }); } -bool NonRecurrentForwardCompTaskNode::IsReadyForBuild() { +bool NormalForwardCompTaskNode::IsReadyForBuild() { return GetConsumedRegst("in")->IsLocked(); } diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h new file mode 100644 index 0000000000000000000000000000000000000000..30f770ffe3cd60b322ef77683c70f684704d3202 --- /dev/null +++ b/oneflow/core/graph/normal_forward_compute_task_node.h @@ -0,0 +1,24 @@ +#ifndef ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ + +#include "oneflow/core/graph/forward_compute_task_node.h" + +namespace oneflow { + +class NormalForwardCompTaskNode final : public ForwardCompTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(NormalForwardCompTaskNode); + NormalForwardCompTaskNode() = default; + ~NormalForwardCompTaskNode() = default; + + TaskType GetTaskType() const override { return TaskType::kNormalForward; } + bool IsReadyForBuild() override; + + private: + void VirtualConsumeInRegst(TaskEdge* edge) override; + void BuildExecGphStructAndBindInRegst() override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 78afcf10df519729a23910a8169d4ea754c529f0..61d3cada38d600b8edc1b9e76320c7ece63d8739 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -3,13 +3,11 @@ namespace oneflow { bool IsForwardTaskType(TaskType tt) { - return tt == TaskType::kNonRecurrentForward - || tt == TaskType::kRecurrentForward; + return tt == TaskType::kNormalForward || tt == TaskType::kRecurrentForward; } bool IsBackwardTaskType(TaskType tt) { - return tt == TaskType::kNonRecurrentBackward - || tt == TaskType::kRecurrentBackward; + return tt == TaskType::kNormalBackward || tt == TaskType::kRecurrentBackward; } TaskNode::TaskNode() : machine_id_(-1), thrd_id_(-1), task_id_(-1) {} diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index de2a4623100eba63403285d1ad4a81dd03dcef5e..983bf5ab2c8f8c025c4bbe6626e6cf33f7e0f473 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -6,9 +6,9 @@ namespace { std::map task_type2color = { {kInvalid, "0"}, - {kNonRecurrentForward, "2"}, + {kNormalForward, "2"}, {kRecurrentForward, "2"}, - {kNonRecurrentBackward, "3"}, + {kNormalBackward, "3"}, {kRecurrentBackward, "3"}, {kSource, "1"}, {kLoss, "4"}, diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index 068ff3978be6765168ac5783a7df541f1992dfea..98b8af46f8ee92541c81a991639b1950dee652f3 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -7,9 +7,9 @@ import "oneflow/core/job/placement.proto"; enum TaskType { kInvalid = 0; - kNonRecurrentForward = 1; + kNormalForward = 1; kRecurrentForward = 2; - kNonRecurrentBackward = 3; + kNormalBackward = 3; kRecurrentBackward = 4; kSource = 5; kLoss = 6;