提交 c7a2b1d7 编写于 作者: W willzhang4a58

nonrecurrent->normal

上级 0833ce24
...@@ -124,6 +124,6 @@ void BackwardCompActor::Act() { ...@@ -124,6 +124,6 @@ void BackwardCompActor::Act() {
} }
} }
REGISTER_ACTOR(TaskType::kNonRecurrentBackward, BackwardCompActor); REGISTER_ACTOR(TaskType::kNormalBackward, BackwardCompActor);
} // namespace oneflow } // namespace oneflow
...@@ -150,7 +150,7 @@ void ForwardCompActor::TryAsyncReturnModelTmpRegst() { ...@@ -150,7 +150,7 @@ void ForwardCompActor::TryAsyncReturnModelTmpRegst() {
} }
} }
REGISTER_ACTOR(TaskType::kNonRecurrentForward, ForwardCompActor); REGISTER_ACTOR(TaskType::kNormalForward, ForwardCompActor);
REGISTER_ACTOR(TaskType::kLoss, ForwardCompActor); REGISTER_ACTOR(TaskType::kLoss, ForwardCompActor);
} // namespace oneflow } // namespace oneflow
#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H #ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_
#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H #define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_
#ifdef WITH_RDMA #ifdef WITH_RDMA
...@@ -49,4 +49,4 @@ class Connection { ...@@ -49,4 +49,4 @@ class Connection {
#endif // WITH_RDMA #endif // WITH_RDMA
#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H #endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_
#include "oneflow/core/graph/chain_node.h" #include "oneflow/core/graph/chain_node.h"
#include "oneflow/core/graph/recurrent_backward_compute_task_node.h" #include "oneflow/core/graph/recurrent_backward_compute_task_node.h"
#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h" #include "oneflow/core/graph/normal_backward_compute_task_node.h"
#include "oneflow/core/graph/recurrent_forward_compute_task_node.h" #include "oneflow/core/graph/recurrent_forward_compute_task_node.h"
#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h" #include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/loss_accumulate_compute_task_node.h" #include "oneflow/core/graph/loss_accumulate_compute_task_node.h"
#include "oneflow/core/graph/loss_compute_task_node.h" #include "oneflow/core/graph/loss_compute_task_node.h"
#include "oneflow/core/graph/loss_print_compute_task_node.h" #include "oneflow/core/graph/loss_print_compute_task_node.h"
...@@ -230,7 +230,7 @@ CompTaskNode* ForwardChainNode::NewCompTaskNode() const { ...@@ -230,7 +230,7 @@ CompTaskNode* ForwardChainNode::NewCompTaskNode() const {
if (HasSoleRecurrentOp()) { if (HasSoleRecurrentOp()) {
return new RecurrentForwardCompTaskNode; return new RecurrentForwardCompTaskNode;
} else { } else {
return new NonRecurrentForwardCompTaskNode; return new NormalForwardCompTaskNode;
} }
} }
...@@ -290,7 +290,7 @@ CompTaskNode* BackwardChainNode::NewCompTaskNode() const { ...@@ -290,7 +290,7 @@ CompTaskNode* BackwardChainNode::NewCompTaskNode() const {
if (HasSoleRecurrentOp()) { if (HasSoleRecurrentOp()) {
return new RecurrentBackwardCompTaskNode; return new RecurrentBackwardCompTaskNode;
} else { } else {
return new NonRecurrentBackwardCompTaskNode; return new NormalBackwardCompTaskNode;
} }
} }
......
#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h" #include "oneflow/core/graph/normal_backward_compute_task_node.h"
#include "oneflow/core/graph/chain_node.h" #include "oneflow/core/graph/chain_node.h"
namespace oneflow { namespace oneflow {
void NonRecurrentBackwardCompTaskNode:: void NormalBackwardCompTaskNode::VirtualBuildExecGphAndBindOutDiffRegst() {
VirtualBuildExecGphAndBindOutDiffRegst() {
HashMap<std::string, std::pair<ExecNode*, std::string>> lbn2producer; HashMap<std::string, std::pair<ExecNode*, std::string>> lbn2producer;
for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) { for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) {
ExecNode* cur_node = mut_exec_gph().NewNode(); ExecNode* cur_node = mut_exec_gph().NewNode();
...@@ -34,7 +33,7 @@ void NonRecurrentBackwardCompTaskNode:: ...@@ -34,7 +33,7 @@ void NonRecurrentBackwardCompTaskNode::
}); });
} }
void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { void NormalBackwardCompTaskNode::VirtualBuildActivationDiffRegst() {
std::shared_ptr<RegstDesc> activation_regst = GetConsumedRegst("activation"); std::shared_ptr<RegstDesc> activation_regst = GetConsumedRegst("activation");
auto activation_diff_regst = GetProducedRegst("activation_diff"); auto activation_diff_regst = GetProducedRegst("activation_diff");
mut_exec_gph().ForEachEdge([&](ExecEdge* edge) { mut_exec_gph().ForEachEdge([&](ExecEdge* edge) {
...@@ -57,7 +56,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() { ...@@ -57,7 +56,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() {
}); });
} }
void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() { void NormalBackwardCompTaskNode::VirtualBuildInDiffRegst() {
std::shared_ptr<RegstDesc> in_diff_regst = GetProducedRegst("in_diff"); std::shared_ptr<RegstDesc> in_diff_regst = GetProducedRegst("in_diff");
std::shared_ptr<RegstDesc> in_regst = GetConsumedRegst("in"); std::shared_ptr<RegstDesc> in_regst = GetConsumedRegst("in");
mut_exec_gph().ForEachNode([&](ExecNode* cur_node) { mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
...@@ -77,7 +76,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() { ...@@ -77,7 +76,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() {
}); });
} }
void NonRecurrentBackwardCompTaskNode::VirtualConsumeInRegst() { void NormalBackwardCompTaskNode::VirtualConsumeInRegst() {
TaskNode* fw_node = GetRelatedFwTaskNode(); TaskNode* fw_node = GetRelatedFwTaskNode();
for (TaskEdge* edge : fw_node->in_edges()) { for (TaskEdge* edge : fw_node->in_edges()) {
TaskNode* pred_fw_node = edge->src_node(); TaskNode* pred_fw_node = edge->src_node();
......
#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ #ifndef ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/backward_compute_task_node.h" #include "oneflow/core/graph/backward_compute_task_node.h"
namespace oneflow { namespace oneflow {
class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode { class NormalBackwardCompTaskNode final : public BackwardCompTaskNode {
public: public:
OF_DISALLOW_COPY_AND_MOVE(NonRecurrentBackwardCompTaskNode); OF_DISALLOW_COPY_AND_MOVE(NormalBackwardCompTaskNode);
NonRecurrentBackwardCompTaskNode() = default; NormalBackwardCompTaskNode() = default;
~NonRecurrentBackwardCompTaskNode() = default; ~NormalBackwardCompTaskNode() = default;
TaskType GetTaskType() const override { TaskType GetTaskType() const override { return TaskType::kNormalBackward; }
return TaskType::kNonRecurrentBackward;
}
private: private:
void VirtualBuildExecGphAndBindOutDiffRegst() override; void VirtualBuildExecGphAndBindOutDiffRegst() override;
...@@ -38,4 +36,4 @@ class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode { ...@@ -38,4 +36,4 @@ class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode {
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_ #endif // ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/forward_compute_task_node.h" #include "oneflow/core/graph/forward_compute_task_node.h"
#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h" #include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/chain_node.h" #include "oneflow/core/graph/chain_node.h"
namespace oneflow { namespace oneflow {
void NonRecurrentForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) { void NormalForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) {
ConsumeRegst("in", edge->GetSoleRegst()); ConsumeRegst("in", edge->GetSoleRegst());
} }
void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
HashMap<std::string, std::pair<ExecNode*, std::string>> lbn2producer; HashMap<std::string, std::pair<ExecNode*, std::string>> lbn2producer;
for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) { for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) {
ExecNode* cur_node = mut_exec_gph().NewNode(); ExecNode* cur_node = mut_exec_gph().NewNode();
...@@ -36,7 +36,7 @@ void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() { ...@@ -36,7 +36,7 @@ void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
}); });
} }
bool NonRecurrentForwardCompTaskNode::IsReadyForBuild() { bool NormalForwardCompTaskNode::IsReadyForBuild() {
return GetConsumedRegst("in")->IsLocked(); return GetConsumedRegst("in")->IsLocked();
} }
......
#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ #ifndef ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/forward_compute_task_node.h" #include "oneflow/core/graph/forward_compute_task_node.h"
namespace oneflow { namespace oneflow {
class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode { class NormalForwardCompTaskNode final : public ForwardCompTaskNode {
public: public:
OF_DISALLOW_COPY_AND_MOVE(NonRecurrentForwardCompTaskNode); OF_DISALLOW_COPY_AND_MOVE(NormalForwardCompTaskNode);
NonRecurrentForwardCompTaskNode() = default; NormalForwardCompTaskNode() = default;
~NonRecurrentForwardCompTaskNode() = default; ~NormalForwardCompTaskNode() = default;
TaskType GetTaskType() const override { TaskType GetTaskType() const override { return TaskType::kNormalForward; }
return TaskType::kNonRecurrentForward;
}
bool IsReadyForBuild() override; bool IsReadyForBuild() override;
private: private:
...@@ -23,4 +21,4 @@ class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode { ...@@ -23,4 +21,4 @@ class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode {
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_ #endif // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
...@@ -3,13 +3,11 @@ ...@@ -3,13 +3,11 @@
namespace oneflow { namespace oneflow {
bool IsForwardTaskType(TaskType tt) { bool IsForwardTaskType(TaskType tt) {
return tt == TaskType::kNonRecurrentForward return tt == TaskType::kNormalForward || tt == TaskType::kRecurrentForward;
|| tt == TaskType::kRecurrentForward;
} }
bool IsBackwardTaskType(TaskType tt) { bool IsBackwardTaskType(TaskType tt) {
return tt == TaskType::kNonRecurrentBackward return tt == TaskType::kNormalBackward || tt == TaskType::kRecurrentBackward;
|| tt == TaskType::kRecurrentBackward;
} }
TaskNode::TaskNode() : machine_id_(-1), thrd_id_(-1), task_id_(-1) {} TaskNode::TaskNode() : machine_id_(-1), thrd_id_(-1), task_id_(-1) {}
......
...@@ -6,9 +6,9 @@ namespace { ...@@ -6,9 +6,9 @@ namespace {
std::map<TaskType, std::string> task_type2color = { std::map<TaskType, std::string> task_type2color = {
{kInvalid, "0"}, {kInvalid, "0"},
{kNonRecurrentForward, "2"}, {kNormalForward, "2"},
{kRecurrentForward, "2"}, {kRecurrentForward, "2"},
{kNonRecurrentBackward, "3"}, {kNormalBackward, "3"},
{kRecurrentBackward, "3"}, {kRecurrentBackward, "3"},
{kSource, "1"}, {kSource, "1"},
{kLoss, "4"}, {kLoss, "4"},
......
...@@ -7,9 +7,9 @@ import "oneflow/core/job/placement.proto"; ...@@ -7,9 +7,9 @@ import "oneflow/core/job/placement.proto";
enum TaskType { enum TaskType {
kInvalid = 0; kInvalid = 0;
kNonRecurrentForward = 1; kNormalForward = 1;
kRecurrentForward = 2; kRecurrentForward = 2;
kNonRecurrentBackward = 3; kNormalBackward = 3;
kRecurrentBackward = 4; kRecurrentBackward = 4;
kSource = 5; kSource = 5;
kLoss = 6; kLoss = 6;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册