未验证 提交 330cf3b8 编写于 作者: C cheng cheng 提交者: GitHub

Remove AreaId (#4283)

* Remove AreaId

* refine check for scope symbol id

* refine logical node macro

* rollback error change in group_boxing_by_dst_parallel
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 83f08932
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/graph/accumulate_compute_task_node.h"
namespace oneflow {
......
......@@ -73,7 +73,7 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
CHECK_NOTNULL(stream_index_generator);
auto stream_index = stream_index_generator->GenerateNcclStreamIndex();
const int64_t thrd_id = SerializeStreamIdToInt64(StreamId{device_id, stream_index});
node->Init(machine_id, thrd_id, NewAreaId(), op_conf);
node->Init(machine_id, thrd_id, op_conf);
}
int64_t FindRootParallelId(const ParallelDesc& multi_device, const ParallelDesc& sole_device) {
......@@ -402,8 +402,8 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
TaskNode* in_node = sorted_in_tasks.at(i);
CollectiveBoxingPackTaskNode* pack_node =
ctx->task_graph()->NewNode<CollectiveBoxingPackTaskNode>();
pack_node->Init(machine_id, thrd_id, NewAreaId(), lbi, logical_blob_desc.shape(),
in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num());
pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,
out_sbp_parallel, in_parallel_desc.parallel_num());
Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), pack_node);
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
......@@ -413,8 +413,8 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
CollectiveBoxingUnpackTaskNode* unpack_node =
ctx->task_graph()->NewNode<CollectiveBoxingUnpackTaskNode>();
unpack_node->Init(machine_id, thrd_id, NewAreaId(), lbi, logical_blob_desc.shape(),
in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num());
unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,
out_sbp_parallel, in_parallel_desc.parallel_num());
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), unpack_node);
sorted_out_tasks->push_back(unpack_node);
}
......
......@@ -75,7 +75,7 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
UNIMPLEMENTED();
}
auto* zeros_node = ctx->task_graph()->NewNode<BoxingZerosTaskNode>();
zeros_node->Init(out_machine_id, thrd_id, NewAreaId(), lbi, logical_blob_desc.shape(),
zeros_node->Init(out_machine_id, thrd_id, lbi, logical_blob_desc.shape(),
logical_blob_desc.data_type(), time_shape);
nearest_in_node->BuildCtrlRegstDesc(zeros_node);
Connect<TaskNode>(nearest_in_node, ctx->task_graph()->NewEdge(), zeros_node);
......
......@@ -18,12 +18,10 @@ limitations under the License.
namespace oneflow {
void BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id,
const LogicalBlobId& lbi) {
void BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi) {
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_area_id(area_id);
}
void BoxingIdentityTaskNode::ProduceAllRegstsAndBindEdges() {
......
......@@ -25,7 +25,7 @@ class BoxingIdentityTaskNode : public TaskNode {
BoxingIdentityTaskNode() = default;
~BoxingIdentityTaskNode() override = default;
void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi);
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi);
TaskType GetTaskType() const override { return TaskType::kBoxingIdentity; }
private:
......
......@@ -18,13 +18,11 @@ limitations under the License.
namespace oneflow {
void BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id,
const LogicalBlobId& lbi, const Shape& shape, DataType data_type,
const Shape& time_shape) {
void BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const Shape& shape, DataType data_type, const Shape& time_shape) {
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_area_id(area_id);
shape_ = shape;
data_type_ = data_type;
time_shape_ = time_shape;
......
......@@ -26,8 +26,8 @@ class BoxingZerosTaskNode : public TaskNode {
BoxingZerosTaskNode() = default;
~BoxingZerosTaskNode() override = default;
void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi,
const Shape& shape, DataType data_type, const Shape& time_shape);
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, const Shape& shape,
DataType data_type, const Shape& time_shape);
TaskType GetTaskType() const override { return TaskType::kBoxingZeros; }
private:
......
......@@ -134,8 +134,6 @@ void ChainActSubGraph::InitNodes(
const TaskProto& task_proto = GetTaskProto(act_event->actor_id());
int64_t chain_id = task_proto.task_set_info().chain_id();
std::pair<int64_t, int64_t> chain_act_id_pair(chain_id, act_id);
// kMdUpdtArea regst num will always be 1
if (task_proto.task_set_info().area_id() == kMdUpdtArea) { continue; }
chain_id_with_act_id2act_events[chain_act_id_pair].push_back(std::move(act_event));
}
for (auto& pair : chain_id_with_act_id2act_events) {
......
......@@ -18,7 +18,7 @@ limitations under the License.
namespace oneflow {
void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id,
void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id,
const LogicalBlobId& lbi, const Shape& logical_shape,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel,
......@@ -26,7 +26,6 @@ void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id, int
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_area_id(area_id);
logical_shape_ = logical_shape;
parallel_num_ = parallel_num;
src_sbp_parallel_ = src_sbp_parallel;
......
......@@ -25,7 +25,7 @@ class CollectiveBoxingPackTaskNode : public TaskNode {
CollectiveBoxingPackTaskNode() = default;
~CollectiveBoxingPackTaskNode() override = default;
void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi,
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const Shape& logical_shape, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel, const int64_t parallel_num);
TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingPack; }
......
......@@ -18,9 +18,8 @@ limitations under the License.
namespace oneflow {
void CollectiveBoxingGenericTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id,
void CollectiveBoxingGenericTaskNode::Init(int64_t machine_id, int64_t thrd_id,
const OperatorConf& op_conf) {
set_area_id(area_id);
set_machine_id(machine_id);
set_thrd_id(thrd_id);
op_conf_ = op_conf;
......
......@@ -26,7 +26,7 @@ class CollectiveBoxingGenericTaskNode : public TaskNode {
CollectiveBoxingGenericTaskNode() = default;
~CollectiveBoxingGenericTaskNode() override = default;
void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const OperatorConf& op_conf);
void Init(int64_t machine_id, int64_t thrd_id, const OperatorConf& op_conf);
private:
void BuildExecGphAndRegst() override;
......
......@@ -18,7 +18,7 @@ limitations under the License.
namespace oneflow {
void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id,
void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id,
const LogicalBlobId& lbi, const Shape& logical_shape,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel,
......@@ -26,7 +26,6 @@ void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id, i
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_area_id(area_id);
logical_shape_ = logical_shape;
parallel_num_ = parallel_num;
src_sbp_parallel_ = src_sbp_parallel;
......
......@@ -26,7 +26,7 @@ class CollectiveBoxingUnpackTaskNode : public TaskNode {
CollectiveBoxingUnpackTaskNode() = default;
~CollectiveBoxingUnpackTaskNode() override = default;
void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi,
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const Shape& logical_shape, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel, const int64_t parallel_num);
......
......@@ -18,20 +18,6 @@ limitations under the License.
#include "oneflow/core/graph/compute_task_node.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/wait_and_send_ids_compute_task_node.h"
#include "oneflow/core/graph/foreign_input_compute_task_node.h"
#include "oneflow/core/graph/foreign_output_compute_task_node.h"
#include "oneflow/core/graph/callback_notify_compute_task_node.h"
#include "oneflow/core/graph/reentrant_lock_compute_task_node.h"
#include "oneflow/core/graph/src_subset_tick_compute_task_node.h"
#include "oneflow/core/graph/dst_subset_tick_compute_task_node.h"
#include "oneflow/core/graph/source_tick_compute_task_node.h"
#include "oneflow/core/graph/tick_compute_task_node.h"
#include "oneflow/core/graph/device_tick_compute_task_node.h"
#include "oneflow/core/graph/acc_tick_compute_task_node.h"
#include "oneflow/core/graph/case_compute_task_node.h"
#include "oneflow/core/graph/esac_compute_task_node.h"
#include "oneflow/core/graph/decode_h2d_compute_task_node.h"
namespace oneflow {
......@@ -71,9 +57,6 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
std::string VisualStr() const;
void GenSortedCompTaskNodes(std::function<void(CompTaskNode*)>) const;
// other
virtual int64_t GetAreaId() const = 0;
protected:
LogicalNode() {}
virtual CompTaskNode* NewCompTaskNode() const = 0;
......@@ -119,111 +102,39 @@ class LogicalEdge final : public Edge<LogicalNode, LogicalEdge> {
BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src, const LogicalNode* dst);
// Declares (without defining) the overrides of TypeName(), NewCompTaskNode() and GetAreaId()
// inside a LogicalNode subclass. The definitions must be supplied elsewhere.
#define OVERRIDE_PURE_VIRTUAL_METHOD() \
std::string TypeName() const override; \
CompTaskNode* NewCompTaskNode() const override; \
int64_t GetAreaId() const override;
// Common class-body boilerplate for a concrete logical node: disables copy/move, defaults the
// constructor and destructor, and declares the three overrides listed in
// OVERRIDE_PURE_VIRTUAL_METHOD().
#define LOGICAL_NODE_BOILERPLATE(class_name) \
OF_DISALLOW_COPY_AND_MOVE(class_name); \
class_name() = default; \
~class_name() = default; \
OVERRIDE_PURE_VIRTUAL_METHOD();
// Intermediate base class for forward logical nodes. It adds no members or overrides of its
// own, so the pure virtual methods inherited from LogicalNode are left for subclasses.
class ForwardLogicalNode : public LogicalNode {
public:
OF_DISALLOW_COPY_AND_MOVE(ForwardLogicalNode);
ForwardLogicalNode() = default;
virtual ~ForwardLogicalNode() = default;
};
// Concrete forward logical node. LOGICAL_NODE_BOILERPLATE only declares TypeName(),
// NewCompTaskNode() and GetAreaId(); their definitions live outside this class body.
class NormalForwardLogicalNode final : public ForwardLogicalNode {
public:
LOGICAL_NODE_BOILERPLATE(NormalForwardLogicalNode);
private:
};
int64_t NewAreaId();
// Class-body boilerplate for `name##LogicalNode` types that own a private area id.
// The generated constructor stores the result of NewAreaId() in area_id_, and GetAreaId()
// returns that stored value. TypeName() yields the stringified `name`, and NewCompTaskNode()
// returns a new `name##CompTaskNode` allocated with `new`.
// The expansion ends inside a `private:` section.
#define LOGICAL_NODE_WITH_NEW_AREA_ID_BOILERPLATE(name) \
public: \
OF_DISALLOW_COPY_AND_MOVE(name##LogicalNode); \
name##LogicalNode() { area_id_ = NewAreaId(); } \
~name##LogicalNode() = default; \
\
std::string TypeName() const override { return #name; } \
CompTaskNode* NewCompTaskNode() const override { return new name##CompTaskNode; } \
int64_t GetAreaId() const override { return area_id_; } \
\
private: \
int64_t area_id_;
// Declares `name##LogicalNode` as a final subclass of ForwardLogicalNode whose body is
// LOGICAL_NODE_WITH_NEW_AREA_ID_BOILERPLATE(name). The expansion has no trailing semicolon,
// so the invocation must supply it.
#define DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(name) \
class name##LogicalNode final : public ForwardLogicalNode { \
LOGICAL_NODE_WITH_NEW_AREA_ID_BOILERPLATE(name) \
\
private: \
}
// Declares `name` (the full class name, not a prefix) as a final subclass of LogicalNode
// using LOGICAL_NODE_BOILERPLATE. The virtual methods are only declared here. The expansion
// has no trailing semicolon, so the invocation must supply it.
#define DECLARE_NAIVE_LOGICAL_NODE(name) \
class name final : public LogicalNode { \
public: \
LOGICAL_NODE_BOILERPLATE(name); \
}
DECLARE_NAIVE_LOGICAL_NODE(DecodeRandomLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(DistributeConcatLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(DistributeSplitLogicalNode);
DECLARE_NAIVE_LOGICAL_NODE(PrintLogicalNode);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(WaitAndSendIds);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(ForeignInput);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(ForeignOutput);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(CallbackNotify);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(ReentrantLock);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(SrcSubsetTick);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(DstSubsetTick);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(SourceTick);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(AccTick);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Tick);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(DeviceTick);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Case);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Esac);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(DecodeH2D);
// Interface for objects that produce the area id of a user op.
// GetAreaId() is non-const, so implementations are allowed to change state when called.
class UserOpAreaIdCreator {
public:
virtual ~UserOpAreaIdCreator() = default;
virtual int64_t GetAreaId() = 0;
};
// UserOpAreaIdCreator that always reports the area id it was constructed with.
class FixedUserOpAreaIdCreator : public UserOpAreaIdCreator {
public:
// Stores `area_id`; every later GetAreaId() call returns this same value.
explicit FixedUserOpAreaIdCreator(int64_t area_id) : area_id_(area_id) {}
~FixedUserOpAreaIdCreator() override = default;
int64_t GetAreaId() override { return area_id_; }
private:
// Area id captured at construction.
int64_t area_id_;
};
// UserOpAreaIdCreator that forwards every GetAreaId() call to NewAreaId().
// It caches nothing, so the value it returns is whatever NewAreaId() yields on that call.
class IndependentUserOpAreaIdCreator : public UserOpAreaIdCreator {
public:
IndependentUserOpAreaIdCreator() = default;
~IndependentUserOpAreaIdCreator() override = default;
int64_t GetAreaId() override { return NewAreaId(); }
};
// Registers, under the string key `op_type_name`, a UserOpAreaIdCreator factory that returns
// a new FixedUserOpAreaIdCreator(area_id). The expansion already ends with a semicolon.
#define REGISTER_USER_OP_AREA_ID(op_type_name, area_id) \
REGISTER_CLASS_CREATOR(std::string, op_type_name, UserOpAreaIdCreator, \
([] { return new FixedUserOpAreaIdCreator(area_id); }));
// Registers, under the string key `op_type_name`, a UserOpAreaIdCreator factory that returns
// a new IndependentUserOpAreaIdCreator. The expansion already ends with a semicolon.
#define REGISTER_USER_OP_INDEPENDENT_AREA_ID(op_type_name) \
REGISTER_CLASS_CREATOR(std::string, op_type_name, UserOpAreaIdCreator, \
([] { return new IndependentUserOpAreaIdCreator(); }));
// Declares `name##LogicalNode` as a final, non-copyable, non-movable subclass of LogicalNode
// with a defaulted constructor and destructor. TypeName() and NewCompTaskNode() are only
// declared here and must be defined elsewhere. The expansion includes the closing `};`.
#define DECLARE_LOGICAL_NODE(name) \
class name##LogicalNode final : public LogicalNode { \
public: \
OF_DISALLOW_COPY_AND_MOVE(name##LogicalNode); \
name##LogicalNode() = default; \
~name##LogicalNode() = default; \
std::string TypeName() const override; \
CompTaskNode* NewCompTaskNode() const override; \
};
DECLARE_LOGICAL_NODE(NormalForward);
// Sequence of logical-node name prefixes, one tuple per entry. It is meant to be expanded
// with OF_PP_FOR_EACH_TUPLE so that a per-name macro (such as DECLARE_LOGICAL_NODE) is applied
// to every entry.
#define LOGICAL_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(WaitAndSendIds) \
OF_PP_MAKE_TUPLE_SEQ(ForeignInput) \
OF_PP_MAKE_TUPLE_SEQ(ForeignOutput) \
OF_PP_MAKE_TUPLE_SEQ(CallbackNotify) \
OF_PP_MAKE_TUPLE_SEQ(ReentrantLock) \
OF_PP_MAKE_TUPLE_SEQ(SrcSubsetTick) \
OF_PP_MAKE_TUPLE_SEQ(DstSubsetTick) \
OF_PP_MAKE_TUPLE_SEQ(SourceTick) \
OF_PP_MAKE_TUPLE_SEQ(AccTick) \
OF_PP_MAKE_TUPLE_SEQ(Tick) \
OF_PP_MAKE_TUPLE_SEQ(DeviceTick) \
OF_PP_MAKE_TUPLE_SEQ(Case) \
OF_PP_MAKE_TUPLE_SEQ(Esac) \
OF_PP_MAKE_TUPLE_SEQ(DecodeH2D) \
OF_PP_MAKE_TUPLE_SEQ(DistributeConcat) \
OF_PP_MAKE_TUPLE_SEQ(DistributeSplit) \
OF_PP_MAKE_TUPLE_SEQ(DecodeRandom) \
OF_PP_MAKE_TUPLE_SEQ(Print)
OF_PP_FOR_EACH_TUPLE(DECLARE_LOGICAL_NODE, LOGICAL_TYPE_SEQ);
class UserOpCompTaskNodeCreator {
public:
......
......@@ -38,7 +38,6 @@ class PlanTaskNode final : public Node<PlanTaskNode, PlanTaskEdge> {
const TaskProto* task_proto() const { return task_proto_; }
int64_t task_id() const { return task_proto_->task_id(); }
int64_t area_id() const { return task_proto_->task_set_info().area_id(); }
int64_t chain_id() const;
int64_t order_in_graph() const { return task_proto_->task_set_info().order_in_graph(); }
......
......@@ -28,7 +28,6 @@ void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView&
mem_zone_id_ = mem_zone_id;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_area_id(kMdUpdtArea);
}
void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
......
......@@ -32,6 +32,9 @@ limitations under the License.
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing_identity_task_node.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
namespace oneflow {
......@@ -54,9 +57,34 @@ bool IsConnectToTickOp(const TaskNode* node) {
return false;
}
bool IsOptimizerPassOp(const Operator* op) {
// NOTE(chengcheng): use scope::calculation_pass_name instead of area_id to not merge optimizer
// ops with fw/bw ops
if (!op->op_conf().has_scope_symbol_id()) {
// NOTE(chengcheng): Some system op insert to OpGraph may not set scope_symbol_id, it MUST NOT
// optimizer subgraph ops.
return false;
}
int64_t scope_symbol_id = op->op_conf().scope_symbol_id();
CHECK(Global<symbol::Storage<Scope>>::Get()->Has(scope_symbol_id))
<< " Error! op : \n " << op->op_conf().DebugString()
<< " has error scope_symbol_id = " << scope_symbol_id
<< " which cannot find in Global<symbol::Storage<Scope>>::Get()\n";
const Scope& scope = Global<symbol::Storage<Scope>>::Get()->Get(scope_symbol_id);
return scope.scope_proto().calculation_pass_name() == kOptimizerPass;
}
bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) {
const OperatorConf& op_conf = op->op_conf();
if (op_conf.has_variable_conf() || op_conf.has_tick_conf() || op_conf.has_device_tick_conf()) {
if (op_conf.has_variable_conf() || op_conf.has_tick_conf() || op_conf.has_device_tick_conf()
|| op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf()
|| op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf()
|| op_conf.has_acc_tick_conf()) {
return true;
}
// NOTE(chengcheng): ONLY nccl_use_compute_stream = false will exclude optimizer pass ops
if (!Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
&& IsOptimizerPassOp(op)) {
return true;
}
return false;
......@@ -97,10 +125,8 @@ void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_
cur_node->set_chain_id(this_chain_id);
cur_node->ForEachNodeOnInOutEdge([&](TaskNode* next_node) {
// NOTE(chengcheng): use area_id to not merge optimizer ops with fw/bw ops
if (visited_nodes.find(next_node) == visited_nodes.end() && CanBeMergedInChain(next_node)
&& this_node->GlobalWorkStreamId() == next_node->GlobalWorkStreamId()
&& this_node->area_id() == next_node->area_id()) {
&& this_node->GlobalWorkStreamId() == next_node->GlobalWorkStreamId()) {
if (next_node->chain_id() == -1) {
queued_nodes.push(next_node);
visited_nodes.insert(next_node);
......@@ -253,7 +279,6 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
logical_node->GenSortedCompTaskNodes([&](CompTaskNode* comp_task_node) {
AddAllocatedNode(comp_task_node);
logical2sorted_comp_tasks[logical_node].push_back(comp_task_node);
comp_task_node->set_area_id(logical_node->GetAreaId());
});
});
......@@ -263,7 +288,6 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
(this->*method)(logical_edge->src_node(), logical_edge->dst_node(),
logical2sorted_comp_tasks.at(logical_edge->src_node()),
logical2sorted_comp_tasks.at(logical_edge->dst_node()), MutBufTask);
SetAreaIdForNewNodes(logical_edge->src_node(), logical_edge->dst_node());
});
logical_gph_->ForEachNecessaryCtrlEdge(
[&](const LogicalNode* src, const LogicalNode* dst, int64_t ctrl_regst_num) {
......@@ -501,19 +525,6 @@ void TaskGraph::EnableInplaceMemSharing(
});
}
// Gives an area id to every task node in the graph whose area id is still kInvalidArea.
// Nodes that already have another area id are left untouched.
//
// For each such node: if the two logical nodes report equal area ids, the node receives the
// source's area id; otherwise it is marked kBoundaryArea.
//
// Both logical node pointers must be non-null (enforced with CHECK).
//
// NOTE(review): GetAreaId() is re-evaluated for every unassigned node, and once more for the
// assignment. Some GetAreaId() implementations appear to obtain a fresh id from NewAreaId()
// on each call, so the number and order of calls may be observable. Confirm before hoisting
// or caching these calls.
void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical,
const LogicalNode* dst_logical) {
CHECK(src_logical != nullptr && dst_logical != nullptr);
ForEachNode([&](TaskNode* node) {
// Skip nodes whose area id has already been decided.
if (node->area_id() != static_cast<int64_t>(kInvalidArea)) return;
if (src_logical->GetAreaId() == dst_logical->GetAreaId()) {
node->set_area_id(src_logical->GetAreaId());
} else {
node->set_area_id(static_cast<int64_t>(kBoundaryArea));
}
});
}
#define DEFINE_BLD_SUB_TASK_GRAPH_METHOD(method_name) \
void TaskGraph::method_name BLD_SUB_TSK_GPH_MTHD_ARGS()
......@@ -526,7 +537,7 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
} else {
for (CompTaskNode* src_node : sorted_src_comp_tasks) {
auto* identity_node = NewNode<BoxingIdentityTaskNode>();
identity_node->Init(src_node->machine_id(), src_node->thrd_id(), src_node->area_id(), lbi);
identity_node->Init(src_node->machine_id(), src_node->thrd_id(), lbi);
Connect<TaskNode>(src_node, NewEdge(), identity_node);
in_nodes.push_back(identity_node);
}
......
......@@ -83,7 +83,6 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes, int64_t ctrl_regst_num);
void SetAreaIdForNewNodes(const LogicalNode* src_logical, const LogicalNode* dst_logical);
void SetOrderInGraphForEachNode();
void MergeChain();
void BuildCtrlRegstDescInSameChain();
......
......@@ -41,12 +41,7 @@ void ForEachDataEdge(const std::unordered_set<TaskEdge*>& edges,
} // namespace
TaskNode::TaskNode()
: machine_id_(-1),
thrd_id_(-1),
task_id_(-1),
area_id_(0),
chain_id_(-1),
order_in_graph_(-1) {}
: machine_id_(-1), thrd_id_(-1), task_id_(-1), chain_id_(-1), order_in_graph_(-1) {}
std::shared_ptr<RegstDesc> TaskNode::GetProducedRegst(const std::string& name) {
auto produced_regsts_it = produced_regsts_.find(name);
......@@ -86,11 +81,6 @@ void TaskNode::set_thrd_id(int64_t val) {
if (machine_id_ != -1) { UpdateTaskId(); }
}
// Sets the area id of this task node to `val`.
// CHECK_EQ requires the current area_id_ to still be 0, so once a non-zero id has been stored
// any further call fails the check.
void TaskNode::set_area_id(int64_t val) {
CHECK_EQ(area_id_, 0);
area_id_ = val;
}
void TaskNode::set_chain_id(int64_t val) {
CHECK_EQ(chain_id_, -1);
chain_id_ = val;
......@@ -168,7 +158,6 @@ void TaskNode::Build() {
if (consumed_regsts_.size()) { CHECK(IsReadyForBuild()); }
BuildExecGphAndRegst();
LockRegsts();
FixRegisterNumRange();
}
void TaskNode::EraseZeroSizeProducedBlob() {
......@@ -222,7 +211,6 @@ void TaskNode::ToProto(TaskProto* task_proto) {
task_proto->set_thrd_id(thrd_id_);
task_proto->set_task_id(task_id_);
task_proto->set_job_id(GlobalJobDesc().job_id());
task_proto->mutable_task_set_info()->set_area_id(area_id_);
task_proto->mutable_task_set_info()->set_chain_id(chain_id_);
task_proto->mutable_task_set_info()->set_order_in_graph(order_in_graph_);
exec_gph_.ToExecSequence(parallel_ctx(), task_proto->mutable_exec_sequence());
......@@ -368,23 +356,6 @@ void TaskNode::LockRegsts() {
for (auto& pair : produced_regsts_) { pair.second->Lock(); }
}
void TaskNode::FixRegisterNumRange() {
for (auto& pair : produced_regsts_) {
RegstDesc* produced_regst = pair.second.get();
bool in_same_stream = true;
for (const TaskNode* consumer : produced_regst->consumers()) {
if (consumer->GlobalWorkStreamId() != GlobalWorkStreamId()) {
in_same_stream = false;
break;
}
}
if (in_same_stream == false && area_id_ != static_cast<int64_t>(kMdUpdtArea)
&& GetTaskType() == TaskType::kCopyHd) { // TODO: delete this hack
if (produced_regst->max_register_num() >= 2) { produced_regst->UpdtMinRegstNumIfNeed(2); }
}
}
}
void TaskNode::UpdateTaskId() {
CHECK_NE(machine_id_, -1);
CHECK_NE(thrd_id_, -1);
......
......@@ -52,7 +52,6 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
int64_t machine_id() const { return machine_id_; }
int64_t thrd_id() const { return thrd_id_; }
int64_t task_id() const { return task_id_; }
int64_t area_id() const { return area_id_; }
int64_t chain_id() const { return chain_id_; }
int64_t order_in_graph() const { return order_in_graph_; }
const ExecGraph& exec_gph() const { return exec_gph_; }
......@@ -69,12 +68,10 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
virtual const ParallelContext* parallel_ctx() const { return nullptr; }
int64_t GlobalWorkStreamId() const;
int64_t GpuPhyId() const { return Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(thrd_id_); }
virtual int64_t AreaId4ChainMerge() const { return area_id(); }
// Setters
void set_machine_id(int64_t val);
void set_thrd_id(int64_t val);
void set_area_id(int64_t val);
void set_chain_id(int64_t val);
void set_order_in_graph(int64_t val);
......@@ -140,7 +137,6 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
virtual void BuildExecGphAndRegst() = 0;
virtual void LockRegsts();
void FixRegisterNumRange();
virtual void InferProducedDataRegstTimeShape() = 0;
void NaiveInferProducedDataRegstTimeShape();
......@@ -156,7 +152,6 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
int64_t machine_id_;
int64_t thrd_id_;
int64_t task_id_;
int64_t area_id_;
int64_t chain_id_;
int64_t order_in_graph_;
......
......@@ -61,6 +61,5 @@ void AccCompTaskNode::BuildExecGphAndRegst() {
}
REGISTER_USER_OP_COMP_TASK_NODE_TYPE("acc", AccCompTaskNode);
REGISTER_USER_OP_INDEPENDENT_AREA_ID("acc")
} // namespace oneflow
......@@ -19,6 +19,20 @@ limitations under the License.
#include "oneflow/core/graph/decode_random_compute_task_node.h"
#include "oneflow/core/graph/distribute_concat_compute_task_node.h"
#include "oneflow/core/graph/distribute_split_compute_task_node.h"
#include "oneflow/core/graph/wait_and_send_ids_compute_task_node.h"
#include "oneflow/core/graph/foreign_input_compute_task_node.h"
#include "oneflow/core/graph/foreign_output_compute_task_node.h"
#include "oneflow/core/graph/callback_notify_compute_task_node.h"
#include "oneflow/core/graph/reentrant_lock_compute_task_node.h"
#include "oneflow/core/graph/src_subset_tick_compute_task_node.h"
#include "oneflow/core/graph/dst_subset_tick_compute_task_node.h"
#include "oneflow/core/graph/source_tick_compute_task_node.h"
#include "oneflow/core/graph/tick_compute_task_node.h"
#include "oneflow/core/graph/device_tick_compute_task_node.h"
#include "oneflow/core/graph/acc_tick_compute_task_node.h"
#include "oneflow/core/graph/case_compute_task_node.h"
#include "oneflow/core/graph/esac_compute_task_node.h"
#include "oneflow/core/graph/decode_h2d_compute_task_node.h"
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/framework/framework.h"
......@@ -279,16 +293,10 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalForward"
"DecodeH2D",
&TaskGraph::BldSubTskGphNormalForwardToDecodeH2D);
#define LOGICAL_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(DistributeConcat, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(DistributeSplit, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(DecodeRandom, kDataPreprocessArea) \
OF_PP_MAKE_TUPLE_SEQ(Print, kPrintArea)
#define DEFINE_VIRTUAL_METHOD(x) \
std::string x##LogicalNode::TypeName() const { return #x; } \
CompTaskNode* x##LogicalNode::NewCompTaskNode() const { return new x##CompTaskNode; }
#define DEFINE_VIRTUAL_METHOD(x, area_type) \
std::string x##LogicalNode::TypeName() const { return #x; } \
CompTaskNode* x##LogicalNode::NewCompTaskNode() const { return new x##CompTaskNode; } \
int64_t x##LogicalNode::GetAreaId() const { return area_type; }
OF_PP_FOR_EACH_TUPLE(DEFINE_VIRTUAL_METHOD, LOGICAL_TYPE_SEQ);
std::string NormalForwardLogicalNode::TypeName() const { return "NormalForward"; }
......@@ -309,32 +317,4 @@ CompTaskNode* NormalForwardLogicalNode::NewCompTaskNode() const {
}
}
int64_t NormalForwardLogicalNode::GetAreaId() const {
if (this->SoleOp()->op_conf().has_user_conf()) {
const std::string& op_type_name = this->SoleOp()->op_conf().user_conf().op_type_name();
if (IsClassRegistered<std::string, UserOpAreaIdCreator>(op_type_name)) {
return std::unique_ptr<UserOpAreaIdCreator>(
NewObj<std::string, UserOpAreaIdCreator>(op_type_name))
->GetAreaId();
} else {
return AreaType::kDataForwardArea;
}
} else {
return AreaType::kDataForwardArea;
}
}
// Hands out a new area id on every call.
// The counter is a function-local static seeded with AreaType_ARRAYSIZE and is incremented
// before being returned, so the first id issued is AreaType_ARRAYSIZE + 1.
int64_t NewAreaId() {
  static int64_t last_issued_id = AreaType_ARRAYSIZE;
  last_issued_id += 1;
  return last_issued_id;
}
REGISTER_USER_OP_AREA_ID("sgd_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("indexed_slices_sgd_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("momentum_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("indexed_slices_momentum_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("adam_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("indexed_slices_adam_update", AreaType::kMdUpdtArea)
REGISTER_USER_OP_AREA_ID("lamb_update", AreaType::kMdUpdtArea)
} // namespace oneflow
......@@ -68,6 +68,5 @@ void PackCompTaskNode::InferProducedDataRegstTimeShape() {
}
REGISTER_USER_OP_COMP_TASK_NODE_TYPE("pack", PackCompTaskNode);
REGISTER_USER_OP_INDEPENDENT_AREA_ID("pack")
} // namespace oneflow
......@@ -75,6 +75,5 @@ void RepeatCompTaskNode::InferProducedDataRegstTimeShape() {
}
REGISTER_USER_OP_COMP_TASK_NODE_TYPE("repeat", RepeatCompTaskNode);
REGISTER_USER_OP_INDEPENDENT_AREA_ID("repeat");
} // namespace oneflow
......@@ -115,6 +115,5 @@ class SspVariableProxyCompTaskNode final : public CompTaskNode {
};
REGISTER_USER_OP_COMP_TASK_NODE_TYPE("ssp_variable_proxy", SspVariableProxyCompTaskNode);
REGISTER_USER_OP_INDEPENDENT_AREA_ID("ssp_variable_proxy");
} // namespace oneflow
......@@ -68,6 +68,5 @@ void UnpackCompTaskNode::InferProducedDataRegstTimeShape() {
}
REGISTER_USER_OP_COMP_TASK_NODE_TYPE("unpack", UnpackCompTaskNode);
REGISTER_USER_OP_INDEPENDENT_AREA_ID("unpack")
} // namespace oneflow
......@@ -41,22 +41,11 @@ enum TaskType {
kBoxingZeros = 64;
};
// Predefined area ids used to group task nodes.
enum AreaType {
// No area assigned yet.
kInvalidArea = 0;
kDataPreprocessArea = 1;
// Area reported for ops that have no more specific area.
kDataForwardArea = 2;
// TODO: rename to OptimizerArea
kMdUpdtArea = 4;
kPrintArea = 6;
// Given to task nodes created between two logical nodes whose area ids differ.
kBoundaryArea = 7;
}
// A list of register-descriptor ids.
message RegstDescIdSet {
repeated int64 regst_desc_id = 1;
}
// Grouping and ordering information attached to a task: its area id, its chain id and its
// order in the graph. All fields are required.
message TaskSetInfo {
required int64 area_id = 1;
required int64 chain_id = 4;
required int64 order_in_graph = 5;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册