未验证 提交 14062a88 编写于 作者: C cheng cheng 提交者: GitHub

Complete support 1 regst 1 blob (#4474)

* half implement of build task graph by 1regst1blob

* Complete support 1 regst 1 blob

* fix check

* Add Lbis in TaskEdge and check valid

* reduce NormalForward out regst name prefix

* refine hasher and proxy key

* fix bug of ProxyKey ==

* fix bug of collective boxing broadcast task edge lbi
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 b53270fc
......@@ -31,8 +31,8 @@ Maybe<SubTskGphBuilderStatus> B21SubTskGphBuilder::Build(
const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(
in_parallel_desc, out_parallel_desc, out_parallel_id);
TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id);
TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(),
out_parallel_desc, out_parallel_id);
TaskNode* proxy =
ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_parallel_id);
sorted_out_tasks->push_back(proxy);
return TRY(BuildSubTskGphBuilderStatus("B21SubTskGphBuilder", ""));
} else {
......
......@@ -73,7 +73,7 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
CHECK_NOTNULL(stream_index_generator);
auto stream_index = stream_index_generator->GenerateNcclStreamIndex();
const int64_t thrd_id = SerializeStreamIdToInt64(StreamId{device_id, stream_index});
node->Init(machine_id, thrd_id, op_conf);
node->Init(machine_id, thrd_id, lbi, op_conf);
}
int64_t FindRootParallelId(const ParallelDesc& multi_device, const ParallelDesc& sole_device) {
......@@ -117,7 +117,7 @@ class NcclCollectiveBoxingAllReduceSubTskGphBuilder final : public SubTskGphBuil
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi,
logical_blob_desc, OpType::kOpTypeAllReduce, -1);
Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);
sorted_out_tasks->push_back(collective_node);
}
return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAllReduceSubTskGphBuilder", ""));
......@@ -154,7 +154,7 @@ class NcclCollectiveBoxingReduceScatterSubTskGphBuilder final : public SubTskGph
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi,
logical_blob_desc, OpType::kOpTypeReduceScatter, -1);
Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);
sorted_out_tasks->push_back(collective_node);
}
return TRY(
......@@ -190,11 +190,11 @@ class NcclCollectiveBoxingAllGatherSubTskGphBuilder final : public SubTskGphBuil
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
TaskNode* in_node = sorted_in_tasks.at(i);
TaskNode* in_node_proxy =
ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_parallel_desc, i);
ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i);
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi,
logical_blob_desc, OpType::kOpTypeAllGather, -1);
Connect<TaskNode>(in_node_proxy, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(in_node_proxy, collective_node, lbi);
sorted_out_tasks->push_back(collective_node);
}
return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAllGatherSubTskGphBuilder", ""));
......@@ -232,7 +232,7 @@ class NcclCollectiveBoxingReduceSubTskGphBuilder final : public SubTskGphBuilder
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi,
logical_blob_desc, OpType::kOpTypeReduce, root_parallel_id);
Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(in_node, collective_node, lbi);
if (i == root_parallel_id) {
sorted_out_tasks->push_back(collective_node);
} else {
......@@ -288,12 +288,12 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
slice_node->ConnectToSrcNodeWithSlice(in_node, ctx->task_graph()->NewEdge(), in_slice);
// copy to dst gpu
TaskNode* slice_node_proxy =
ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), out_parallel_desc, out_id);
ctx->task_graph()->GetProxyNode(slice_node, lbi, out_parallel_desc, out_id);
// allgather
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
NcclInitCollectiveNode(collective_node, out_parallel_desc, out_id, op_name, lbi,
logical_blob_desc, OpType::kOpTypeAllGather, -1);
Connect<TaskNode>(slice_node_proxy, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(slice_node_proxy, collective_node, lbi);
sorted_out_tasks->push_back(collective_node);
}
return TRY(BuildSubTskGphBuilderStatus(
......@@ -330,8 +330,8 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil
auto* cpu_in_node = sorted_in_tasks.front();
root_parallel_id =
SubTskGphBuilderUtil::FindNearestSrcParallelId(out_parallel_desc, in_parallel_desc, 0);
gpu_in_node = ctx->GetProxyNode(cpu_in_node, cpu_in_node->MemZoneId121(), out_parallel_desc,
root_parallel_id);
gpu_in_node =
ctx->task_graph()->GetProxyNode(cpu_in_node, lbi, out_parallel_desc, root_parallel_id);
} else if (in_parallel_desc.device_type() == DeviceType::kGPU) {
root_parallel_id = FindRootParallelId(out_parallel_desc, in_parallel_desc);
......@@ -347,10 +347,11 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil
NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi,
logical_blob_desc, OpType::kOpTypeBroadcast, root_parallel_id);
if (i == root_parallel_id) {
Connect<TaskNode>(gpu_in_node, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(gpu_in_node, collective_node, lbi);
} else {
gpu_in_node->BuildCtrlRegstDesc(collective_node);
Connect<TaskNode>(gpu_in_node, ctx->task_graph()->NewEdge(), collective_node);
Connect<TaskNode>(gpu_in_node, ctx->task_graph()->NewTaskEdgeWithLbi(lbi),
collective_node);
}
sorted_out_tasks->push_back(collective_node);
}
......@@ -402,18 +403,18 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
ctx->task_graph()->NewNode<CollectiveBoxingPackTaskNode>();
pack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,
out_sbp_parallel, in_parallel_desc.parallel_num());
Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), pack_node);
ctx->task_graph()->ConnectWithLbi(in_node, pack_node, lbi);
auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>();
NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi,
logical_blob_desc, OpType::kOpTypeAll2All, -1);
Connect<TaskNode>(pack_node, ctx->task_graph()->NewEdge(), collective_node);
ctx->task_graph()->ConnectWithLbi(pack_node, collective_node, lbi);
CollectiveBoxingUnpackTaskNode* unpack_node =
ctx->task_graph()->NewNode<CollectiveBoxingUnpackTaskNode>();
unpack_node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), in_sbp_parallel,
out_sbp_parallel, in_parallel_desc.parallel_num());
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), unpack_node);
ctx->task_graph()->ConnectWithLbi(collective_node, unpack_node, lbi);
sorted_out_tasks->push_back(unpack_node);
}
return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAll2AllSubTskGphBuilder", ""));
......
......@@ -31,8 +31,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2BSubTskGphBuilder::Build(
const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(
in_parallel_desc, out_parallel_desc, out_id);
TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id);
TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(),
out_parallel_desc, out_id);
TaskNode* proxy =
ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id);
sorted_out_tasks->push_back(proxy);
}
return TRY(BuildSubTskGphBuilderStatus("NaiveB2BSubTskGphBuilder", ""));
......
......@@ -50,8 +50,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
const int64_t nearest_in_id = out_id2nearest_in_id.at(out_id);
TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_id);
if (out_id == nearest_out_node_idx) {
TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(),
out_parallel_desc, out_id);
TaskNode* proxy =
ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_id);
sorted_out_tasks->push_back(proxy);
} else {
......
......@@ -30,8 +30,7 @@ Maybe<SubTskGphBuilderStatus> OneToOneSubTskGphBuilder::Build(
&& in_sbp_parallel == out_sbp_parallel)) {
for (int64_t i = 0; i < in_parallel_desc.parallel_num(); ++i) {
TaskNode* in_node = sorted_in_tasks.at(i);
// TODO(liujuncheng): use lbi
TaskNode* proxy = ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_parallel_desc, i);
TaskNode* proxy = ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, i);
sorted_out_tasks->push_back(proxy);
}
return TRY(BuildSubTskGphBuilderStatus("OneToOneSubTskGphBuilder", ""));
......
......@@ -163,7 +163,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
dst_node->ConnectToSrcNodeWithSlice(src_node, NewEdge(), src_slice);
return dst_node;
};
const auto BuildSubTaskGphS2B = [&ctx, &CreateBoxingNode121, &NewEdge](
const auto BuildSubTaskGphS2B = [&ctx, &CreateBoxingNode121, &NewEdge, &lbi](
const ParallelDesc& in_pd, const ParallelDesc& out_pd,
const SbpParallel& in_sbp, const SbpParallel& out_sbp,
const BlobDesc& blob_desc,
......@@ -183,9 +183,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
if (SubTskGphBuilderUtil::IsOnSameGPU(in_node, out_node)) {
out_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice);
} else {
TaskNode* proxy_node =
ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(
in_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice);
}
}
......@@ -291,9 +290,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
in_slices.at(in_id));
}
}
TaskNode* local_add_proxy_node =
ctx->GetProxyNode(local_concat_node, Global<IDMgr>::Get()->CpuMemZoneId(),
out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode(
local_concat_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), concat_slice);
}
}
......@@ -353,9 +351,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
for (const int64_t in_id : in_parallel_ids) {
local_add_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(), in_slice);
}
TaskNode* local_add_proxy_node =
ctx->GetProxyNode(local_add_node, Global<IDMgr>::Get()->CpuMemZoneId(),
out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode(
local_add_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), out_slice);
}
}
......@@ -409,18 +406,17 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
const int64_t out_machine_id = machine_id7out_parallel_ids.first;
TaskNode* in_box_node = nullptr;
if (out_box_nodes.size() == 1) {
in_box_node = ctx->GetProxyNode(
out_box_nodes.front(), out_box_nodes.front()->MemZoneId121(),
machine_id7out_parallel_ids.first, Global<IDMgr>::Get()->CpuMemZoneId());
in_box_node = ctx->task_graph()->GetProxyNode(out_box_nodes.front(), lbi,
machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->CpuMemZoneId());
} else {
auto* add_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();
add_node->Init(lbi, slice, kSliceBoxingTaskModeAdd, machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->PickCpuThrdIdEvenly(machine_id7out_parallel_ids.first),
Global<IDMgr>::Get()->CpuMemZoneId());
for (TaskNode* out_box_node : out_box_nodes) {
TaskNode* out_boxing_node_proxy =
ctx->GetProxyNode(out_box_node, out_box_node->MemZoneId121(), out_machine_id,
Global<IDMgr>::Get()->CpuMemZoneId());
TaskNode* out_boxing_node_proxy = ctx->task_graph()->GetProxyNode(
out_box_node, lbi, out_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
add_node->ConnectToSrcNodeWithSlice(out_boxing_node_proxy, NewEdge(), slice);
}
in_box_node = add_node;
......@@ -435,8 +431,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
} else {
UNIMPLEMENTED();
}
(*out_nodes)[out_id] = ctx->GetProxyNode(in_box_node, Global<IDMgr>::Get()->CpuMemZoneId(),
out_machine_id, mem_zone_id);
(*out_nodes)[out_id] =
ctx->task_graph()->GetProxyNode(in_box_node, lbi, out_machine_id, mem_zone_id);
}
}
};
......@@ -460,8 +456,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
SliceBoxingTaskNode* slice_node =
CreateBoxingNode121(in_pd, nearest_idx, out_slice, kSliceBoxingTaskModeCopy);
slice_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice);
TaskNode* out_node =
ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), out_pd, out_id);
TaskNode* out_node = ctx->task_graph()->GetProxyNode(slice_node, lbi, out_pd, out_id);
out_nodes->push_back(out_node);
}
......
......@@ -21,70 +21,4 @@ SubTskGphBuilderCtx::SubTskGphBuilderCtx(TaskGraph* task_graph) : task_graph_(ta
TaskGraph* SubTskGphBuilderCtx::task_graph() { return task_graph_; }
TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id,
int64_t dst_machine_id, int64_t dst_mem_zone_id) {
const auto key = std::make_pair(dst_machine_id, dst_mem_zone_id);
if (node2proxies_.find(src_node) != node2proxies_.cend()
&& node2proxies_.at(src_node).find(key) != node2proxies_.at(src_node).cend()) {
return node2proxies_.at(src_node).at(key);
} else {
if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) {
node2proxies_[src_node][key] = src_node;
return src_node;
} else if (Global<IDMgr>::Get()->IsGpuMemZone(dst_mem_zone_id)) {
TaskNode* proxy_on_dst_host = GetProxyNode(src_node, src_mem_zone_id, dst_machine_id,
Global<IDMgr>::Get()->CpuMemZoneId());
CopyHdTaskNode* copy_task = task_graph()->NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id));
Connect<TaskNode>(proxy_on_dst_host, task_graph()->NewEdge(), copy_task);
node2proxies_[src_node][key] = copy_task;
return copy_task;
} else if (Global<IDMgr>::Get()->IsCpuMemZone(dst_mem_zone_id)) {
if (src_node->machine_id() == dst_machine_id) {
if (Global<IDMgr>::Get()->IsGpuMemZone(src_mem_zone_id)) {
CopyHdTaskNode* copy_task = task_graph()->NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id));
Connect<TaskNode>(src_node, task_graph()->NewEdge(), copy_task);
node2proxies_[src_node][key] = copy_task;
return copy_task;
} else {
UNIMPLEMENTED();
}
} else {
TaskNode* proxy_on_src_host =
GetProxyNode(src_node, src_mem_zone_id, src_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
CopyCommNetTaskNode* copy_comm_net_task = task_graph()->NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst_machine_id);
Connect<TaskNode>(proxy_on_src_host, task_graph()->NewEdge(), copy_comm_net_task);
node2proxies_[src_node][key] = copy_comm_net_task;
return copy_comm_net_task;
}
} else {
UNIMPLEMENTED();
}
}
}
TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, const int64_t src_mem_zone_id,
const ParallelDesc& dst_parallel_desc,
const int64_t dst_parallel_id) {
const int64_t dst_machine_id =
CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));
int64_t dst_mem_zone_id;
const IDMgr* id_mgr = Global<IDMgr>::Get();
if (dst_parallel_desc.device_type() == DeviceType::kCPU) {
dst_mem_zone_id = id_mgr->CpuMemZoneId();
} else if (dst_parallel_desc.device_type() == DeviceType::kGPU) {
const int64_t dst_dev_phy_id =
CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id);
} else {
UNIMPLEMENTED();
}
return GetProxyNode(src_node, src_mem_zone_id, dst_machine_id, dst_mem_zone_id);
}
} // namespace oneflow
......@@ -18,13 +18,9 @@ limitations under the License.
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/memory/memory_allocator.h"
namespace oneflow {
class TaskGraph;
class TaskNode;
class SubTskGphBuilderCtx final {
public:
OF_DISALLOW_COPY_AND_MOVE(SubTskGphBuilderCtx);
......@@ -32,21 +28,9 @@ class SubTskGphBuilderCtx final {
virtual ~SubTskGphBuilderCtx() = default;
virtual TaskGraph* task_graph();
TaskNode* GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id, int64_t dst_machine_id,
int64_t dst_mem_zone_id);
TaskNode* GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id,
const ParallelDesc& dst_parallel_desc, const int64_t dst_parallel_id);
template<typename T1, typename T2>
void ConnectAll121(const std::vector<T1*>& src_nodes, const std::vector<T2*>& dst_nodes) {
CHECK_EQ(src_nodes.size(), dst_nodes.size());
FOR_RANGE(int64_t, i, 0, dst_nodes.size()) {
Connect<TaskNode>(src_nodes.at(i), task_graph()->NewEdge(), dst_nodes.at(i));
}
}
private:
TaskGraph* task_graph_;
HashMap<TaskNode*, HashMap<std::pair<int64_t, int64_t>, TaskNode*>> node2proxies_;
};
} // namespace oneflow
......
......@@ -19,9 +19,9 @@ limitations under the License.
namespace oneflow {
void BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi) {
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
}
void BoxingIdentityTaskNode::ProduceAllRegstsAndBindEdges() {
......@@ -39,7 +39,7 @@ void BoxingIdentityTaskNode::BuildExecGphAndRegst() {
OperatorConf op_conf;
op_conf.set_name("System-Boxing-Identity-" + NewUniqueId());
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
*op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi_;
*op_conf.mutable_boxing_identity_conf()->mutable_lbi() = lbi();
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf);
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
......
......@@ -15,11 +15,12 @@ limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
namespace oneflow {
class BoxingIdentityTaskNode : public TaskNode {
class BoxingIdentityTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityTaskNode);
BoxingIdentityTaskNode() = default;
......@@ -33,8 +34,6 @@ class BoxingIdentityTaskNode : public TaskNode {
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() final;
void InferProducedDataRegstTimeShape() final;
LogicalBlobId lbi_;
};
} // namespace oneflow
......
......@@ -20,9 +20,9 @@ namespace oneflow {
void BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const Shape& shape, DataType data_type, const Shape& time_shape) {
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
shape_ = shape;
data_type_ = data_type;
time_shape_ = time_shape;
......@@ -42,7 +42,7 @@ void BoxingZerosTaskNode::BuildExecGphAndRegst() {
OperatorConf op_conf;
op_conf.set_name("System-Boxing-Zeros-" + NewUniqueId());
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
*op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi_;
*op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi();
shape_.ToProto(op_conf.mutable_boxing_zeros_conf()->mutable_shape());
op_conf.mutable_boxing_zeros_conf()->set_data_type(data_type_);
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf);
......
......@@ -16,11 +16,11 @@ limitations under the License.
#ifndef ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
namespace oneflow {
class BoxingZerosTaskNode : public TaskNode {
class BoxingZerosTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(BoxingZerosTaskNode);
BoxingZerosTaskNode() = default;
......@@ -36,7 +36,6 @@ class BoxingZerosTaskNode : public TaskNode {
void ConsumeAllRegsts() final;
void InferProducedDataRegstTimeShape() final;
LogicalBlobId lbi_;
Shape shape_;
DataType data_type_;
Shape time_shape_;
......
......@@ -23,9 +23,9 @@ void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel,
const int64_t parallel_num) {
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
logical_shape_ = logical_shape;
parallel_num_ = parallel_num;
src_sbp_parallel_ = src_sbp_parallel;
......@@ -48,7 +48,7 @@ void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() {
op_conf.set_name("System-Collective-Boxing-Pack-" + NewUniqueId());
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
auto* collective_boxing_pack_conf = op_conf.mutable_collective_boxing_pack_conf();
*collective_boxing_pack_conf->mutable_lbi() = lbi_;
*collective_boxing_pack_conf->mutable_lbi() = lbi();
logical_shape_.ToProto(collective_boxing_pack_conf->mutable_logical_shape());
*collective_boxing_pack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_;
*collective_boxing_pack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_;
......
......@@ -15,11 +15,12 @@ limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
namespace oneflow {
class CollectiveBoxingPackTaskNode : public TaskNode {
class CollectiveBoxingPackTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackTaskNode);
CollectiveBoxingPackTaskNode() = default;
......@@ -36,7 +37,6 @@ class CollectiveBoxingPackTaskNode : public TaskNode {
void ConsumeAllRegsts() final;
void InferProducedDataRegstTimeShape() final;
LogicalBlobId lbi_;
Shape logical_shape_;
SbpParallel src_sbp_parallel_;
SbpParallel dst_sbp_parallel_;
......
......@@ -19,9 +19,10 @@ limitations under the License.
namespace oneflow {
void CollectiveBoxingGenericTaskNode::Init(int64_t machine_id, int64_t thrd_id,
const OperatorConf& op_conf) {
const LogicalBlobId& lbi, const OperatorConf& op_conf) {
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
op_conf_ = op_conf;
}
......
......@@ -16,17 +16,18 @@ limitations under the License.
#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
namespace oneflow {
class CollectiveBoxingGenericTaskNode : public TaskNode {
class CollectiveBoxingGenericTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericTaskNode);
CollectiveBoxingGenericTaskNode() = default;
~CollectiveBoxingGenericTaskNode() override = default;
void Init(int64_t machine_id, int64_t thrd_id, const OperatorConf& op_conf);
void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi,
const OperatorConf& op_conf);
private:
void BuildExecGphAndRegst() override;
......
......@@ -23,9 +23,9 @@ void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel,
const int64_t parallel_num) {
lbi_ = lbi;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
logical_shape_ = logical_shape;
parallel_num_ = parallel_num;
src_sbp_parallel_ = src_sbp_parallel;
......@@ -48,7 +48,7 @@ void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() {
op_conf.set_name("System-Collective-Boxing-Unpack-" + NewUniqueId());
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
auto* collective_boxing_unpack_conf = op_conf.mutable_collective_boxing_unpack_conf();
*collective_boxing_unpack_conf->mutable_lbi() = lbi_;
*collective_boxing_unpack_conf->mutable_lbi() = lbi();
logical_shape_.ToProto(collective_boxing_unpack_conf->mutable_logical_shape());
*collective_boxing_unpack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_;
*collective_boxing_unpack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_;
......
......@@ -16,11 +16,11 @@ limitations under the License.
#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
namespace oneflow {
class CollectiveBoxingUnpackTaskNode : public TaskNode {
class CollectiveBoxingUnpackTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackTaskNode);
CollectiveBoxingUnpackTaskNode() = default;
......@@ -38,7 +38,6 @@ class CollectiveBoxingUnpackTaskNode : public TaskNode {
void ConsumeAllRegsts() final;
void InferProducedDataRegstTimeShape() final;
LogicalBlobId lbi_;
Shape logical_shape_;
SbpParallel src_sbp_parallel_;
SbpParallel dst_sbp_parallel_;
......
......@@ -57,7 +57,8 @@ void CopyTaskNode::BuildExecGphAndRegst() {
void CopyTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); }
void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, int64_t machine_id, int64_t dev_phy_id) {
void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, int64_t machine_id, int64_t dev_phy_id,
const LogicalBlobId& lbi) {
copy_type_ = copy_type;
set_machine_id(machine_id);
DeviceId device_id{static_cast<DeviceId::rank_t>(machine_id), DeviceType::kGPU,
......@@ -73,6 +74,7 @@ void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, int64_t machine_id, int6
UNIMPLEMENTED();
}
set_thrd_id(SerializeStreamIdToInt64(StreamId{device_id, stream_index}));
set_lbi(lbi);
}
void CopyHdTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
......@@ -98,7 +100,7 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
return conf;
}
void CopyCommNetTaskNode::Init(int64_t machine_id) {
void CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) {
set_machine_id(machine_id);
DeviceId device_id{static_cast<DeviceId::rank_t>(machine_id), DeviceType::kCPU,
DeviceId::kCPUDeviceIndex};
......@@ -107,6 +109,7 @@ void CopyCommNetTaskNode::Init(int64_t machine_id) {
CHECK_NOTNULL(generator);
StreamId stream_id{device_id, generator->GenerateCommNetStreamIndex()};
set_thrd_id(SerializeStreamIdToInt64(stream_id));
set_lbi(lbi);
}
void CopyCommNetTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
......
......@@ -16,11 +16,11 @@ limitations under the License.
#ifndef ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
namespace oneflow {
class CopyTaskNode : public TaskNode {
class CopyTaskNode : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyTaskNode);
CopyTaskNode() = default;
......@@ -45,7 +45,7 @@ class CopyHdTaskNode final : public CopyTaskNode {
TaskType GetTaskType() const override { return TaskType::kCopyHd; }
void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id);
void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id, const LogicalBlobId& lbi);
CopyHdOpConf::Type copy_type() const { return copy_type_; }
int64_t MemZoneId121() const override {
......@@ -74,7 +74,7 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
TaskType GetTaskType() const override { return TaskType::kCopyCommNet; }
void Init(int64_t machine_id);
void Init(int64_t machine_id, const LogicalBlobId& lbi);
private:
void InitProducedRegstMemCase(MemoryCase*) override;
......
......@@ -34,8 +34,6 @@ class NormalForwardCompTaskNode final : public CompTaskNode {
bool HasBackwardCompTaskNode();
private:
bool IsAllOutNodeNormalForward() const;
bool CanProduceSeperatedRegstsForEachOutBlob() const;
void ProduceOutRegstByNameAndBlockNum(const std::string& name, size_t mem_block_num);
void BuildExecGphAndRegst() override;
void BuildExecGphStructAndBindInRegst();
......
......@@ -21,13 +21,13 @@ namespace oneflow {
void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
const SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id,
int64_t mem_zone_id) {
lbi_ = lbi;
out_slice_ = out_slice;
out_shape_ = out_slice.shape();
mode_ = mode;
mem_zone_id_ = mem_zone_id;
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
}
void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
......@@ -77,7 +77,7 @@ void SliceBoxingTaskNode::BuildExecGphAndRegst() {
node->BindBnWithRegst(ibn, GetSoleConsumedRegst("in_" + std::to_string(i)));
}
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(lbi_);
out_regst->AddLbi(lbi());
node->BindBnWithRegst(op->SoleObn(), out_regst);
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
......@@ -94,6 +94,7 @@ void SliceBoxingTaskNode::SetInDataEdgeSlice(const TaskEdge* edge, const TensorS
void SliceBoxingTaskNode::ConnectToSrcNodeWithSlice(TaskNode* src, TaskEdge* edge,
const TensorSliceView& slice) {
edge->AddLbi(lbi());
Connect<TaskNode>(src, edge, this);
SetInDataEdgeSlice(edge, slice);
}
......@@ -104,7 +105,7 @@ OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() {
OperatorConf op_conf{};
op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(device_type())));
SliceBoxingConf boxing_conf{};
*boxing_conf.mutable_lbi() = lbi_;
*boxing_conf.mutable_lbi() = lbi();
out_slice_.ToProto(boxing_conf.mutable_out_slice());
out_shape_.ToProto(boxing_conf.mutable_out_shape());
for (const TaskEdge* edge : ordered_in_data_edges_) {
......
......@@ -16,7 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_SLICE_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/transport_task_node.h"
#include "oneflow/core/register/tensor_slice_view.h"
namespace oneflow {
......@@ -27,7 +27,7 @@ enum SliceBoxingTaskMode {
kSliceBoxingTaskModeAdd,
};
class SliceBoxingTaskNode final : public TaskNode {
class SliceBoxingTaskNode final : public TransportTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(SliceBoxingTaskNode);
SliceBoxingTaskNode() = default;
......@@ -49,10 +49,10 @@ class SliceBoxingTaskNode final : public TaskNode {
void InferProducedDataRegstTimeShape() override;
OperatorConf GetBoxingOpConf();
void InitProducedRegstMemCase(MemoryCase*) override;
int64_t MemZoneId121() const override { return mem_zone_id_; }
HashMap<const TaskEdge*, TensorSliceView> in_data_edge2slice_;
std::vector<const TaskEdge*> ordered_in_data_edges_;
LogicalBlobId lbi_;
TensorSliceView out_slice_;
Shape out_shape_;
SliceBoxingTaskMode mode_ = kSliceBoxingTaskModeInvalid;
......
......@@ -395,32 +395,24 @@ TaskGraph::TaskGraph() {
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
HashMap<const OpNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;
HashMap<CompTaskNode*, HashMap<int64_t, std::vector<TaskNode*>>> buf_task;
auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) {
auto& buf_vec = buf_task[task_node][machine_id];
if (buf_vec.empty()) {
buf_vec.assign(Global<ResourceDesc, ForSession>::Get()->MemZoneNum(), nullptr);
}
return &(buf_vec.at(mem_zone_id));
};
HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;
op_graph->ForEachNode([&](const OpNode* op_node) {
std::vector<CompTaskNode*>* sorted_comp_tasks = &(logical2sorted_comp_tasks[op_node]);
std::vector<CompTaskNode*>* sorted_comp_tasks = &(op_node2sorted_comp_tasks[op_node]);
GenSortedCompTaskNodes(op_node, sorted_comp_tasks);
for (CompTaskNode* comp_task : *sorted_comp_tasks) { AddAllocatedNode(comp_task); }
});
op_graph->ForEachEdge([&](const OpEdge* op_edge) {
// TODO(chengcheng): ForEachLbi for one regst one blob.
BldSubTskGphMthd method = GetMthdForBldSubTskGph(op_edge);
(this->*method)(op_edge, logical2sorted_comp_tasks.at(op_edge->src_node()),
logical2sorted_comp_tasks.at(op_edge->dst_node()), MutBufTask);
(this->*method)(op_edge, op_node2sorted_comp_tasks.at(op_edge->src_node()),
op_node2sorted_comp_tasks.at(op_edge->dst_node()));
});
ForEachOpGraphNecessaryCtrlEdge(
op_graph, [&](const OpNode* src, const OpNode* dst, int64_t ctrl_regst_num) {
const auto& src_task_nodes = logical2sorted_comp_tasks.at(src);
const auto& dst_task_nodes = logical2sorted_comp_tasks.at(dst);
const auto& src_task_nodes = op_node2sorted_comp_tasks.at(src);
const auto& dst_task_nodes = op_node2sorted_comp_tasks.at(dst);
if (src->op().op_conf().has_src_subset_tick_conf()) {
UNIMPLEMENTED();
} else if (dst->op().op_conf().has_dst_subset_tick_conf()) {
......@@ -436,17 +428,82 @@ TaskGraph::TaskGraph() {
TaskGraph::~TaskGraph() = default;
Maybe<void> TaskGraph::ConnectDstSubsetTickEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes) {
std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;
JUST(MakeGetterTaskNode4MachineId7ThrdId(dst_task_nodes, &TaskNode4MachineId7ThrdId));
for (CompTaskNode* src_task_node : src_task_nodes) {
CompTaskNode* dst_task_node =
JUST(TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id()));
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_node, edge, dst_task_node);
TaskEdge* TaskGraph::NewTaskEdgeWithLbi(const LogicalBlobId& lbi) {
TaskEdge* edge = NewEdge();
edge->AddLbi(lbi);
return edge;
}
TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis) {
TaskEdge* edge = NewEdge();
edge->AddLbis(lbis);
return edge;
}
TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
int64_t dst_machine_id, int64_t dst_mem_zone_id) {
int64_t src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_machine_id, dst_mem_zone_id);
if (proxy2node.find(key) != proxy2node.cend()) {
return proxy2node.at(key);
} else {
if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) {
proxy2node[key] = src_node;
return src_node;
} else if (Global<IDMgr>::Get()->IsGpuMemZone(dst_mem_zone_id)) {
TaskNode* proxy_on_dst_host =
GetProxyNode(src_node, lbi, dst_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else if (Global<IDMgr>::Get()->IsCpuMemZone(dst_mem_zone_id)) {
if (src_node->machine_id() == dst_machine_id) {
if (Global<IDMgr>::Get()->IsGpuMemZone(src_mem_zone_id)) {
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id), lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
UNIMPLEMENTED();
}
} else {
TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, src_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst_machine_id, lbi);
Connect<TaskNode>(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task);
proxy2node[key] = copy_comm_net_task;
return copy_comm_net_task;
}
} else {
UNIMPLEMENTED();
}
}
return Maybe<void>::Ok();
UNIMPLEMENTED();
return nullptr;
}
TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) {
const int64_t dst_machine_id =
CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));
int64_t dst_mem_zone_id;
const IDMgr* id_mgr = Global<IDMgr>::Get();
if (dst_parallel_desc.device_type() == DeviceType::kCPU) {
dst_mem_zone_id = id_mgr->CpuMemZoneId();
} else if (dst_parallel_desc.device_type() == DeviceType::kGPU) {
const int64_t dst_dev_phy_id =
CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id);
} else {
UNIMPLEMENTED();
}
return GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
}
void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
......@@ -640,17 +697,7 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
const OpNode* src_op_node = op_edge->src_node();
const OpNode* dst_op_node = op_edge->dst_node();
for (const LogicalBlobId& lbi : op_edge->lbis()) {
std::vector<TaskNode*> in_nodes;
if (op_edge->lbis().size() == 1) {
in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());
} else {
for (CompTaskNode* src_node : sorted_src_comp_tasks) {
auto* identity_node = NewNode<BoxingIdentityTaskNode>();
identity_node->Init(src_node->machine_id(), src_node->thrd_id(), lbi);
Connect<TaskNode>(src_node, NewEdge(), identity_node);
in_nodes.push_back(identity_node);
}
}
std::vector<TaskNode*> in_nodes(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());
std::vector<TaskNode*> out_nodes;
out_nodes.reserve(sorted_dst_comp_tasks.size());
std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks;
......@@ -668,7 +715,10 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
boxing_logger_->Log(*status, src_op_node->op().op_name(), dst_op_node->op().op_name(),
src_parallel_desc, dst_parallel_desc, src_parallel_distribution,
dst_parallel_distribution, lbi, blob_desc);
sub_tsk_gph_builder_ctx_->ConnectAll121(out_nodes, sorted_dst_comp_tasks);
CHECK_EQ(out_nodes.size(), sorted_dst_comp_tasks.size());
FOR_RANGE(size_t, i, 0, out_nodes.size()) {
ConnectWithLbi(out_nodes.at(i), sorted_dst_comp_tasks.at(i), lbi);
}
if (!sorted_ctrl_tasks.empty()) {
CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size());
FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) {
......@@ -687,9 +737,9 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
CompTaskNode* src = sorted_src_comp_tasks.at(i);
CompTaskNode* dst = sorted_dst_comp_tasks.at(i);
BuildTaskPath(src, dst, MutBufTask, true);
for (const LogicalBlobId& lbi : op_edge->lbis()) {
BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(i), lbi);
}
}
}
......@@ -698,7 +748,9 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast) {
CompTaskNode* nearest_src_node =
SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node);
CHECK_NOTNULL(nearest_src_node);
BuildTaskPath(nearest_src_node, dst_node, MutBufTask, true);
for (const LogicalBlobId& lbi : op_edge->lbis()) {
BuildTaskPath(nearest_src_node, dst_node, lbi);
}
}
}
......@@ -712,7 +764,7 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect) {
FOR_RANGE(int, i, 0, sorted_dst_comp_tasks.size()) {
const auto& lbi = dst_op.BnInOp2Lbi(dst_op.input_bns().Get(i));
if (lbis.find(lbi) != lbis.end()) {
BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), MutBufTask, true);
BuildTaskPath(sorted_src_comp_tasks.at(0), sorted_dst_comp_tasks.at(i), lbi);
}
}
}
......@@ -727,30 +779,31 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
FOR_RANGE(int, i, 0, sorted_src_comp_tasks.size()) {
const auto& lbi = src_op.BnInOp2Lbi(src_op.output_bns().Get(i));
if (lbis.find(lbi) != lbis.end()) {
BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), MutBufTask, true);
BuildTaskPath(sorted_src_comp_tasks.at(i), sorted_dst_comp_tasks.at(0), lbi);
}
}
}
Maybe<void> TaskGraph::ConnectSrcSubsetTickEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes) {
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) {
std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;
JUST(MakeGetterTaskNode4MachineId7ThrdId(src_task_nodes, &TaskNode4MachineId7ThrdId));
for (CompTaskNode* dst_task_node : dst_task_nodes) {
CompTaskNode* src_task_node =
JUST(TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id()));
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_node, edge, dst_task_node);
CHECK_JUST(
MakeGetterTaskNode4MachineId7ThrdId(sorted_src_comp_tasks, &TaskNode4MachineId7ThrdId));
for (CompTaskNode* dst_task_node : sorted_dst_comp_tasks) {
CompTaskNode* src_task_node = CHECK_JUST(
TaskNode4MachineId7ThrdId(dst_task_node->machine_id(), dst_task_node->thrd_id()));
Connect<TaskNode>(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node);
}
return Maybe<void>::Ok();
}
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphBySrcSubsetConnect) {
CHECK_JUST(ConnectSrcSubsetTickEdges(sorted_src_comp_tasks, sorted_dst_comp_tasks));
}
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByDstSubsetConnect) {
CHECK_JUST(ConnectDstSubsetTickEdges(sorted_src_comp_tasks, sorted_dst_comp_tasks));
std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;
CHECK_JUST(
MakeGetterTaskNode4MachineId7ThrdId(sorted_dst_comp_tasks, &TaskNode4MachineId7ThrdId));
for (CompTaskNode* src_task_node : sorted_src_comp_tasks) {
CompTaskNode* dst_task_node = CHECK_JUST(
TaskNode4MachineId7ThrdId(src_task_node->machine_id(), src_task_node->thrd_id()));
Connect<TaskNode>(src_task_node, NewTaskEdgeWithLbis(op_edge->lbis()), dst_task_node);
}
}
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {
......@@ -758,103 +811,30 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {
FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
CompTaskNode* src = sorted_src_comp_tasks.at(i);
CompTaskNode* dst = sorted_dst_comp_tasks.at(i);
Connect<TaskNode>(src, NewEdge(), dst);
for (const LogicalBlobId& lbi : op_edge->lbis()) { BuildTaskPath(src, dst, lbi); }
}
}
void TaskGraph::BuildTaskPath(
CompTaskNode* src, CompTaskNode* dst,
std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
MutBufTask,
bool use_buf_task_node) {
CHECK_NE(src, dst);
auto GetBufTask = [&](int64_t machine_id, int32_t mem_zone_id) {
return *MutBufTask(src, machine_id, mem_zone_id);
};
auto SetBufTask = [&](int64_t machine_id, int32_t mem_zone_id, TaskNode* new_val) {
TaskNode** cur_val = MutBufTask(src, machine_id, mem_zone_id);
if (*cur_val == nullptr) {
*cur_val = new_val;
} else {
CHECK_EQ(*cur_val, new_val);
}
return new_val;
};
TaskNode* cur_node = src;
while (cur_node->machine_id() != dst->machine_id()
|| cur_node->MemZoneId121() != dst->MemZoneId121()) {
cur_node = BuildTaskStep(cur_node, dst, GetBufTask, SetBufTask, use_buf_task_node);
}
if (cur_node != dst) { Connect<TaskNode>(cur_node, NewEdge(), dst); }
}
TaskNode* TaskGraph::BuildTaskStep(
TaskNode* cur_node, TaskNode* dst,
const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id)>& GetBufTask,
const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id, TaskNode*)>& SetBufTask,
bool use_buf_task_node) {
int32_t cpu_mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();
int32_t next_mem_zone_id = -1;
TaskNode* next_node = nullptr;
if (cur_node->MemZoneId121() != cpu_mem_zone_id) {
next_mem_zone_id = cpu_mem_zone_id;
if (!use_buf_task_node || !(next_node = GetBufTask(cur_node->machine_id(), next_mem_zone_id))) {
next_node = AddCopyD2HTaskFrom(cur_node);
Connect<TaskNode>(cur_node, NewEdge(), next_node);
}
} else if (cur_node->machine_id() == dst->machine_id()) {
next_mem_zone_id = dst->MemZoneId121();
if (!use_buf_task_node || !(next_node = GetBufTask(cur_node->machine_id(), next_mem_zone_id))) {
next_node = TryAddCopyH2DTaskTo(dst);
if (next_node == nullptr) { next_node = dst; }
Connect<TaskNode>(cur_node, NewEdge(), next_node);
void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
if (src_node == dst_node) { return; }
for (TaskEdge* out_edge : src_node->out_edges()) {
TaskNode* out_node = out_edge->dst_node();
if (out_node == dst_node) {
out_edge->AddLbi(lbi);
return;
}
} else if (cur_node->machine_id() != dst->machine_id()) {
next_mem_zone_id = cpu_mem_zone_id;
if (!use_buf_task_node || !(next_node = GetBufTask(dst->machine_id(), next_mem_zone_id))) {
next_node = AddCopyCommNetTaskBetween(cur_node, dst);
Connect<TaskNode>(cur_node, NewEdge(), next_node);
}
} else {
UNIMPLEMENTED();
}
if (use_buf_task_node && (next_node != dst)) {
SetBufTask(next_node->machine_id(), next_mem_zone_id, next_node);
}
return next_node;
}
TaskNode* TaskGraph::TryAddCopyH2DTaskTo(TaskNode* task) {
if (IsInterfaceTask(task)) { return nullptr; }
if (IsClassRegistered<int32_t, TickTockTaskType>(task->GetTaskType())) { return nullptr; }
CHECK_EQ(task->device_type(), DeviceType::kGPU);
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, task->machine_id(), task->GpuPhyId());
return copy_task;
TaskEdge* connected_edge = NewEdge();
connected_edge->AddLbi(lbi);
Connect<TaskNode>(src_node, connected_edge, dst_node);
}
TaskNode* TaskGraph::AddCopyD2HTaskFrom(TaskNode* task) {
CHECK_EQ(task->device_type(), DeviceType::kGPU);
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, task->machine_id(), task->GpuPhyId());
return copy_task;
}
TaskNode* TaskGraph::AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst) {
CHECK_NE(src->machine_id(), dst->machine_id());
CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst->machine_id());
return copy_comm_net_task;
}
void TaskGraph::ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst) {
if (src->machine_id() == dst->machine_id()) {
Connect(src, NewEdge(), dst);
} else {
TaskNode* copy_comm_net_task = AddCopyCommNetTaskBetween(src, dst);
Connect<TaskNode>(src, NewEdge(), copy_comm_net_task);
Connect<TaskNode>(copy_comm_net_task, NewEdge(), dst);
}
void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
int64_t dst_machine_id = dst_node->machine_id();
int64_t dst_mem_zone_id = dst_node->MemZoneId121();
TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
ConnectWithLbi(proxy_node, dst_node, lbi);
}
} // namespace oneflow
......@@ -30,11 +30,9 @@ namespace oneflow {
class SubTskGphBuilderCtx;
class HierarchicalSubTskGphBuilder;
#define BLD_SUB_TSK_GPH_MTHD_ARGS() \
(const OpEdge* op_edge, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, \
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, \
std::function<TaskNode**(CompTaskNode * src, int64_t machine_id, int32_t mem_zone_id)> \
MutBufTask)
#define BLD_SUB_TSK_GPH_MTHD_ARGS() \
(const OpEdge* op_edge, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, \
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks)
class TaskGraph;
using BldSubTskGphMthd = void(TaskGraph::*) BLD_SUB_TSK_GPH_MTHD_ARGS();
......@@ -53,6 +51,17 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&
IsOpNameDataOrCtrlReachable);
TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, int64_t dst_machine_id,
int64_t dst_mem_zone_id);
TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id);
TaskEdge* NewTaskEdgeWithLbi(const LogicalBlobId& lbi);
TaskEdge* NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis);
void ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi);
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing);
......@@ -65,26 +74,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D);
private:
void BuildTaskPath(
CompTaskNode* src, CompTaskNode* dst,
std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
MutBufTask,
bool use_buf_task_node);
TaskNode* BuildTaskStep(
TaskNode* cur_node, TaskNode* dst,
const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id)>& GetBufTask,
const std::function<TaskNode*(int64_t machine_id, int32_t mem_zone_id, TaskNode*)>&
SetBufTask,
bool use_buf_task_node);
TaskNode* TryAddCopyH2DTaskTo(TaskNode*);
TaskNode* AddCopyD2HTaskFrom(TaskNode*);
TaskNode* AddCopyCommNetTaskBetween(TaskNode* src, TaskNode* dst);
void ConnectWithCopyCommNetIfNeed(TaskNode* src, TaskNode* dst);
Maybe<void> ConnectSrcSubsetTickEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes);
Maybe<void> ConnectDstSubsetTickEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes);
void BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi);
void ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes, int64_t ctrl_regst_num);
......@@ -109,6 +100,30 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
std::unique_ptr<HierarchicalSubTskGphBuilder> hierarchical_sub_tsk_gph_builder_;
std::unique_ptr<SubTskGphBuilderCtx> sub_tsk_gph_builder_ctx_;
std::unique_ptr<BoxingLogger> boxing_logger_;
struct ProxyKey {
TaskNode* src_node;
LogicalBlobId lbi;
int64_t dst_machine_id;
int64_t dst_mem_zone_id;
ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, int64_t arg_machine, int64_t arg_zone)
: src_node(src), lbi(arg_lbi), dst_machine_id(arg_machine), dst_mem_zone_id(arg_zone) {}
bool operator==(const ProxyKey& other) const {
return src_node == other.src_node && lbi == other.lbi
&& dst_machine_id == other.dst_machine_id && dst_mem_zone_id == other.dst_mem_zone_id;
}
struct Hasher {
inline size_t operator()(const ProxyKey& key) const {
return std::hash<TaskNode*>{}(key.src_node) ^ std::hash<LogicalBlobId>{}(key.lbi)
^ key.dst_machine_id ^ key.dst_mem_zone_id;
}
};
};
HashMap<ProxyKey, TaskNode*, ProxyKey::Hasher> proxy2node;
};
} // namespace oneflow
......
......@@ -354,7 +354,13 @@ void TaskNode::TryLockConsumedRegst(const std::string& name) {
}
void TaskNode::LockRegsts() {
for (auto& pair : produced_regsts_) { pair.second->Lock(); }
for (auto& pair : produced_regsts_) {
std::shared_ptr<RegstDesc> regst = pair.second;
regst->Lock();
// NOTE(chengcheng): CHECK 1 regst 1 blob.
if (regst->regst_desc_type().has_data_regst_desc()) { CHECK_LE(regst->NumOfLbi(), 1); }
}
}
void TaskNode::UpdateTaskId() {
......@@ -397,6 +403,30 @@ void TaskEdge::AddRegst(const std::string& name_in_producer,
CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second);
}
void TaskEdge::CheckRegstLbiValid() const {
HashMap<LogicalBlobId, std::shared_ptr<RegstDesc>> lbi2data_regst;
for (auto& pair : name_in_producer2regst_) {
std::shared_ptr<RegstDesc> regst = pair.second;
if (regst->regst_desc_type().has_data_regst_desc()) {
// NOTE(chengcheng): regst_desc_type is Set, BUT regst_desc_type.data_regst_desc is UNSET!
// So you can ONLY use NumOfLbi and ForEachLbi interface.
CHECK_EQ(regst->NumOfLbi(), 1);
regst->ForEachLbi(
[&](const LogicalBlobId& lbi) { CHECK(lbi2data_regst.emplace(lbi, regst).second); });
}
}
CHECK_EQ(lbi2data_regst.size(), lbis_.size())
<< " \n\n TaskEdge lbi and regst NOT match."
<< " TaskEdge: edge_id = " << edge_id() << " From: [" << src_node()->VisualStr() << "] To: ["
<< dst_node()->VisualStr() << "]\n";
for (auto& lbi : lbis_) {
CHECK(lbi2data_regst.find(lbi) != lbi2data_regst.end())
<< " \n\n Cannot find lbi: " << lbi.DebugString() << " in TaskEdge From: ["
<< src_node()->VisualStr() << "] To: [" << dst_node()->VisualStr() << "]\n\n";
}
}
RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto,
const std::string& regst_desc_name) {
auto* produced_regst_desc = task_proto->mutable_produced_regst_desc();
......
......@@ -169,10 +169,16 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
std::shared_ptr<RegstDesc> GetRegst(const std::string& name_in_producer) const;
std::shared_ptr<RegstDesc> GetSoleRegst() const;
std::vector<std::shared_ptr<RegstDesc>> GetRegsts() const;
const HashSet<LogicalBlobId>& GetLbis() const { return lbis_; }
void AddRegst(const std::string& name_in_producer, const std::shared_ptr<RegstDesc>& regst);
void AddLbi(const LogicalBlobId& lbi) { lbis_.insert(lbi); }
void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }
void CheckRegstLbiValid() const;
private:
HashSet<LogicalBlobId> lbis_;
HashMap<std::string, std::shared_ptr<RegstDesc>> name_in_producer2regst_;
};
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/register/logical_blob_id.pb.h"
namespace oneflow {
class TransportTaskNode : public TaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(TransportTaskNode);
TransportTaskNode() = default;
virtual ~TransportTaskNode() = default;
void set_lbi(const LogicalBlobId& lbi) { lbi_ = lbi; }
LogicalBlobId lbi() const { return lbi_; }
private:
LogicalBlobId lbi_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_TRANSPORT_TASK_NODE_H_
......@@ -32,27 +32,12 @@ size_t RegstNum4OpSameOutputBlob(OperatorConf::OpTypeCase op_type_case) {
}
}
std::string GetOutRegstNameByObn(const std::string& obn) {
return "NormalForwardCompTaskNodeOutRegstName_" + obn;
}
std::string GetOutRegstNameByObn(const std::string& obn) { return "__" + obn; }
} // namespace
bool NormalForwardCompTaskNode::HasBackwardCompTaskNode() { return false; }
bool NormalForwardCompTaskNode::CanProduceSeperatedRegstsForEachOutBlob() const {
return op()->output_bns().size() > 1 && IsAllOutNodeNormalForward();
}
bool NormalForwardCompTaskNode::IsAllOutNodeNormalForward() const {
bool ret = true;
ForEachNodeOnOutDataEdge([&](TaskNode* node) {
auto* fw_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
if (fw_node == nullptr) { ret = false; }
});
return ret;
}
void NormalForwardCompTaskNode::ProduceOutRegstByNameAndBlockNum(const std::string& name,
size_t mem_block_num) {
if (mem_block_num != -1) {
......@@ -76,33 +61,21 @@ void NormalForwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
}
// when output blob num > 1 and task node on out edge is all NormalForwardCompTaskNode ,
// create multi out regst by output blob name in op
if (CanProduceSeperatedRegstsForEachOutBlob()) {
HashMap<LogicalBlobId, std::string> lbi2out_regst_name;
for (const std::string& obn : sole_op->output_bns()) {
const LogicalBlobId& lbi = sole_op->BnInOp2Lbi(obn);
std::string out_regst_name = GetOutRegstNameByObn(obn);
lbi2out_regst_name.insert({lbi, out_regst_name});
ProduceOutRegstByNameAndBlockNum(out_regst_name, mem_block_num);
}
ForEachOutDataEdge([&](TaskEdge* edge) {
TaskNode* node = edge->dst_node();
auto* dst_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
CHECK(dst_node != nullptr) << "1regst1blob ONLY support normal fw comp task node 121";
std::shared_ptr<const Operator> dst_op = dst_node->op();
bool is_found = false;
for (const std::string& ibn : dst_op->input_bns()) {
const LogicalBlobId& dst_in_lbi = dst_op->BnInOp2Lbi(ibn);
if (lbi2out_regst_name.find(dst_in_lbi) != lbi2out_regst_name.end()) {
is_found = true;
BindEdgeWithProducedRegst(edge, lbi2out_regst_name.at(dst_in_lbi));
}
}
CHECK(is_found) << "Cannot find comsumed blob in dst op: " << dst_op->op_name();
});
} else {
ProduceOutRegstByNameAndBlockNum("out", mem_block_num);
ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); });
HashMap<LogicalBlobId, std::string> lbi2out_regst_name;
for (const std::string& obn : sole_op->output_bns()) {
const LogicalBlobId& lbi = sole_op->BnInOp2Lbi(obn);
std::string out_regst_name = GetOutRegstNameByObn(obn);
lbi2out_regst_name.insert({lbi, out_regst_name});
ProduceOutRegstByNameAndBlockNum(out_regst_name, mem_block_num);
}
ForEachOutDataEdge([&](TaskEdge* edge) {
for (const LogicalBlobId& lbi : edge->GetLbis()) {
auto it = lbi2out_regst_name.find(lbi);
CHECK(it != lbi2out_regst_name.end());
BindEdgeWithProducedRegst(edge, it->second);
}
});
ProduceRegst("tmp", true);
}
......@@ -127,45 +100,21 @@ void NormalForwardCompTaskNode::BuildExecGphAndRegst() {
}
void NormalForwardCompTaskNode::BuildExecGphStructAndBindInRegst() {
HashMap<LogicalBlobId, std::pair<ExecNode*, std::string>> lbi2producer;
ExecNode* cur_node = mut_exec_gph().NewNode();
cur_node->mut_op() = op();
for (const std::string& obn : op()->output_bns()) {
const LogicalBlobId& lbi = op()->BnInOp2Lbi(obn);
CHECK(lbi2producer.insert({lbi, {cur_node, obn}}).second);
}
const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst("in");
for (const std::string& ibn : cur_node->op()->input_bns()) {
const LogicalBlobId& lbi = cur_node->op()->BnInOp2Lbi(ibn);
auto producer_it = lbi2producer.find(lbi);
if (producer_it != lbi2producer.end()) {
ExecEdge* edge = mut_exec_gph().NewEdge();
edge->set_lbi(lbi);
edge->mut_src_bn() = producer_it->second.second;
edge->mut_dst_bn() = ibn;
Connect(producer_it->second.first, edge, cur_node);
} else {
cur_node->BindBnWithOneOfTheRegsts(ibn, in_regsts);
}
cur_node->BindBnWithOneOfTheRegsts(ibn, in_regsts);
}
}
void NormalForwardCompTaskNode::BuildOutRegst() {
if (CanProduceSeperatedRegstsForEachOutBlob()) {
ExecNode* exec_node = mut_exec_gph().SoleNode();
for (const std::string& obn : exec_node->op()->output_bns()) {
std::string out_regst_name = GetOutRegstNameByObn(obn);
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(out_regst_name);
out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));
exec_node->BindBnWithRegst(obn, out_regst);
}
} else {
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
ExecNode* exec_node = mut_exec_gph().SoleNode();
for (const std::string& obn : exec_node->op()->output_bns()) {
out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));
exec_node->BindBnWithRegst(obn, out_regst);
}
ExecNode* exec_node = mut_exec_gph().SoleNode();
for (const std::string& obn : exec_node->op()->output_bns()) {
std::string out_regst_name = GetOutRegstNameByObn(obn);
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst(out_regst_name);
out_regst->AddLbi(exec_node->op()->BnInOp2Lbi(obn));
exec_node->BindBnWithRegst(obn, out_regst);
}
}
......
......@@ -87,6 +87,8 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
}
task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });
task_gph->ForEachNode([&](TaskNode* task_node) {
if (task_node->IsMeaningLess()) { return; }
task_node->ToProto(plan->mutable_task()->Add());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册