提交 c7a2b1d7 编写于 作者: W willzhang4a58

nonrecurrent->normal

上级 0833ce24
......@@ -124,6 +124,6 @@ void BackwardCompActor::Act() {
}
}
REGISTER_ACTOR(TaskType::kNonRecurrentBackward, BackwardCompActor);
REGISTER_ACTOR(TaskType::kNormalBackward, BackwardCompActor);
} // namespace oneflow
......@@ -150,7 +150,7 @@ void ForwardCompActor::TryAsyncReturnModelTmpRegst() {
}
}
REGISTER_ACTOR(TaskType::kNonRecurrentForward, ForwardCompActor);
REGISTER_ACTOR(TaskType::kNormalForward, ForwardCompActor);
REGISTER_ACTOR(TaskType::kLoss, ForwardCompActor);
} // namespace oneflow
#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H
#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H
#ifndef ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_
#define ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_
#ifdef WITH_RDMA
......@@ -49,4 +49,4 @@ class Connection {
#endif // WITH_RDMA
#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H
#endif // ONEFLOW_CORE_COMM_NETWORK_RDMA_CONNECTION_H_
#include "oneflow/core/graph/chain_node.h"
#include "oneflow/core/graph/recurrent_backward_compute_task_node.h"
#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h"
#include "oneflow/core/graph/normal_backward_compute_task_node.h"
#include "oneflow/core/graph/recurrent_forward_compute_task_node.h"
#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h"
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/loss_accumulate_compute_task_node.h"
#include "oneflow/core/graph/loss_compute_task_node.h"
#include "oneflow/core/graph/loss_print_compute_task_node.h"
......@@ -230,7 +230,7 @@ CompTaskNode* ForwardChainNode::NewCompTaskNode() const {
if (HasSoleRecurrentOp()) {
return new RecurrentForwardCompTaskNode;
} else {
return new NonRecurrentForwardCompTaskNode;
return new NormalForwardCompTaskNode;
}
}
......@@ -290,7 +290,7 @@ CompTaskNode* BackwardChainNode::NewCompTaskNode() const {
if (HasSoleRecurrentOp()) {
return new RecurrentBackwardCompTaskNode;
} else {
return new NonRecurrentBackwardCompTaskNode;
return new NormalBackwardCompTaskNode;
}
}
......
#include "oneflow/core/graph/nonrecurrent_backward_compute_task_node.h"
#include "oneflow/core/graph/normal_backward_compute_task_node.h"
#include "oneflow/core/graph/chain_node.h"
namespace oneflow {
void NonRecurrentBackwardCompTaskNode::
VirtualBuildExecGphAndBindOutDiffRegst() {
void NormalBackwardCompTaskNode::VirtualBuildExecGphAndBindOutDiffRegst() {
HashMap<std::string, std::pair<ExecNode*, std::string>> lbn2producer;
for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) {
ExecNode* cur_node = mut_exec_gph().NewNode();
......@@ -34,7 +33,7 @@ void NonRecurrentBackwardCompTaskNode::
});
}
void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() {
void NormalBackwardCompTaskNode::VirtualBuildActivationDiffRegst() {
std::shared_ptr<RegstDesc> activation_regst = GetConsumedRegst("activation");
auto activation_diff_regst = GetProducedRegst("activation_diff");
mut_exec_gph().ForEachEdge([&](ExecEdge* edge) {
......@@ -57,7 +56,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildActivationDiffRegst() {
});
}
void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() {
void NormalBackwardCompTaskNode::VirtualBuildInDiffRegst() {
std::shared_ptr<RegstDesc> in_diff_regst = GetProducedRegst("in_diff");
std::shared_ptr<RegstDesc> in_regst = GetConsumedRegst("in");
mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
......@@ -77,7 +76,7 @@ void NonRecurrentBackwardCompTaskNode::VirtualBuildInDiffRegst() {
});
}
void NonRecurrentBackwardCompTaskNode::VirtualConsumeInRegst() {
void NormalBackwardCompTaskNode::VirtualConsumeInRegst() {
TaskNode* fw_node = GetRelatedFwTaskNode();
for (TaskEdge* edge : fw_node->in_edges()) {
TaskNode* pred_fw_node = edge->src_node();
......
#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_
#ifndef ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/backward_compute_task_node.h"
namespace oneflow {
class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode {
class NormalBackwardCompTaskNode final : public BackwardCompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(NonRecurrentBackwardCompTaskNode);
NonRecurrentBackwardCompTaskNode() = default;
~NonRecurrentBackwardCompTaskNode() = default;
OF_DISALLOW_COPY_AND_MOVE(NormalBackwardCompTaskNode);
NormalBackwardCompTaskNode() = default;
~NormalBackwardCompTaskNode() = default;
TaskType GetTaskType() const override {
return TaskType::kNonRecurrentBackward;
}
TaskType GetTaskType() const override { return TaskType::kNormalBackward; }
private:
void VirtualBuildExecGphAndBindOutDiffRegst() override;
......@@ -38,4 +36,4 @@ class NonRecurrentBackwardCompTaskNode final : public BackwardCompTaskNode {
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_BACKWARD_COMPUTE_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_NORMAL_BACKWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/forward_compute_task_node.h"
#include "oneflow/core/graph/nonrecurrent_forward_compute_task_node.h"
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/chain_node.h"
namespace oneflow {
void NonRecurrentForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) {
void NormalForwardCompTaskNode::VirtualConsumeInRegst(TaskEdge* edge) {
ConsumeRegst("in", edge->GetSoleRegst());
}
void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
HashMap<std::string, std::pair<ExecNode*, std::string>> lbn2producer;
for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) {
ExecNode* cur_node = mut_exec_gph().NewNode();
......@@ -36,7 +36,7 @@ void NonRecurrentForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
});
}
bool NonRecurrentForwardCompTaskNode::IsReadyForBuild() {
bool NormalForwardCompTaskNode::IsReadyForBuild() {
return GetConsumedRegst("in")->IsLocked();
}
......
#ifndef ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_
#ifndef ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/forward_compute_task_node.h"
namespace oneflow {
class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode {
class NormalForwardCompTaskNode final : public ForwardCompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(NonRecurrentForwardCompTaskNode);
NonRecurrentForwardCompTaskNode() = default;
~NonRecurrentForwardCompTaskNode() = default;
OF_DISALLOW_COPY_AND_MOVE(NormalForwardCompTaskNode);
NormalForwardCompTaskNode() = default;
~NormalForwardCompTaskNode() = default;
TaskType GetTaskType() const override {
return TaskType::kNonRecurrentForward;
}
TaskType GetTaskType() const override { return TaskType::kNormalForward; }
bool IsReadyForBuild() override;
private:
......@@ -23,4 +21,4 @@ class NonRecurrentForwardCompTaskNode final : public ForwardCompTaskNode {
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NONRECURRENT_FORWARD_COMPUTE_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
......@@ -3,13 +3,11 @@
namespace oneflow {
bool IsForwardTaskType(TaskType tt) {
return tt == TaskType::kNonRecurrentForward
|| tt == TaskType::kRecurrentForward;
return tt == TaskType::kNormalForward || tt == TaskType::kRecurrentForward;
}
bool IsBackwardTaskType(TaskType tt) {
return tt == TaskType::kNonRecurrentBackward
|| tt == TaskType::kRecurrentBackward;
return tt == TaskType::kNormalBackward || tt == TaskType::kRecurrentBackward;
}
TaskNode::TaskNode() : machine_id_(-1), thrd_id_(-1), task_id_(-1) {}
......
......@@ -6,9 +6,9 @@ namespace {
std::map<TaskType, std::string> task_type2color = {
{kInvalid, "0"},
{kNonRecurrentForward, "2"},
{kNormalForward, "2"},
{kRecurrentForward, "2"},
{kNonRecurrentBackward, "3"},
{kNormalBackward, "3"},
{kRecurrentBackward, "3"},
{kSource, "1"},
{kLoss, "4"},
......
......@@ -7,9 +7,9 @@ import "oneflow/core/job/placement.proto";
enum TaskType {
kInvalid = 0;
kNonRecurrentForward = 1;
kNormalForward = 1;
kRecurrentForward = 2;
kNonRecurrentBackward = 3;
kNormalBackward = 3;
kRecurrentBackward = 4;
kSource = 5;
kLoss = 6;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册