未验证 提交 e9a20bfe 编写于 作者: G guo ran 提交者: GitHub

Add hierarchical_sub_task_graph_builder (#4393)

* Add hierarchical_sub_task_graph_builder

* fix

* fix

* fix

* refine

* refine

* refine

* refine
Co-authored-by: Ncheng cheng <472491134@qq.com>
上级 c2b57ffd
......@@ -21,54 +21,59 @@ namespace oneflow {
namespace {
#define OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD \
"src_op_name,dst_op_name,src_parallel_conf,dst_parallel_conf," \
"src_sbp_conf,dst_sbp_conf,lbi,dtype,shape,builder,comment\n"
"src_op_name,dst_op_name,src_parallel_desc,dst_parallel_desc," \
"src_parallel_distribution," \
"dst_parallel_distribution,lbi,dtype,shape,builder,comment\n"
std::string ParallelDescToString(const ParallelDesc& parallel_desc) {
std::string serialized_parallel_desc;
std::string device_type;
device_type = *CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type()));
auto sorted_machine_ids = parallel_desc.sorted_machine_ids();
serialized_parallel_desc += "{";
for (int64_t i = 0; i < sorted_machine_ids.size(); ++i) {
const int64_t machine_id = sorted_machine_ids.at(i);
serialized_parallel_desc += std::to_string(machine_id) + ":" + device_type + ":";
int64_t min_id = parallel_desc.sorted_dev_phy_ids(machine_id).front();
int64_t max_id = parallel_desc.sorted_dev_phy_ids(machine_id).back();
serialized_parallel_desc += std::to_string(min_id) + "-" + std::to_string(max_id);
if (i != sorted_machine_ids.size() - 1) { serialized_parallel_desc += " "; }
serialized_parallel_desc += " ";
}
serialized_parallel_desc += parallel_desc.hierarchy()->DebugStr();
serialized_parallel_desc += "}";
return serialized_parallel_desc;
}
std::string ShapeToString(const Shape& shape) {
std::stringstream shape_ss;
auto dim_vec = shape.dim_vec();
shape_ss << "[";
for (int32_t i = 0; i < dim_vec.size(); ++i) {
shape_ss << dim_vec.at(i);
if (i != dim_vec.size() - 1) { shape_ss << " "; }
std::string ParallelDistributionToString(const ParallelDistribution& parallel_distribution) {
std::string serialized_parallel_distribution;
const int64_t num_axes = parallel_distribution.sbp_parallel_size();
serialized_parallel_distribution += "[";
for (int64_t i = 0; i < num_axes - 1; ++i) {
serialized_parallel_distribution +=
SbpParallelToString(parallel_distribution.sbp_parallel(i)) + " ";
}
shape_ss << "]";
return shape_ss.str();
serialized_parallel_distribution +=
SbpParallelToString(parallel_distribution.sbp_parallel(num_axes - 1)) + "]";
return serialized_parallel_distribution;
}
std::string MakeBoxingLoggerCsvRow(const SubTskGphBuilderStatus& status,
const std::string& src_op_name, const std::string& dst_op_name,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc) {
const ParallelDistribution& src_parallel_distribution,
const ParallelDistribution& dst_parallel_distribution,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) {
std::string serialized_status;
serialized_status += src_op_name + ",";
serialized_status += dst_op_name + ",";
serialized_status += ParallelDescToString(src_parallel_desc) + ",";
serialized_status += ParallelDescToString(dst_parallel_desc) + ",";
serialized_status += SbpParallelToString(src_sbp_parallel) + ",";
serialized_status += SbpParallelToString(dst_sbp_parallel) + ",";
serialized_status += ParallelDistributionToString(src_parallel_distribution) + ",";
serialized_status += ParallelDistributionToString(dst_parallel_distribution) + ",";
serialized_status += GenLogicalBlobName(lbi) + ",";
serialized_status += DataType_Name(logical_blob_desc.data_type()) + ",";
serialized_status += ShapeToString(logical_blob_desc.shape()) + ",";
serialized_status += logical_blob_desc.shape().DebugStr() + ",";
serialized_status += status.builder_name() + ",";
if (status.comment().empty()) {
serialized_status += "-";
......@@ -91,11 +96,12 @@ CsvBoxingLogger::~CsvBoxingLogger() { log_stream_->Flush(); }
void CsvBoxingLogger::Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,
const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel,
const ParallelDistribution& src_parallel_distribution,
const ParallelDistribution& dst_parallel_distribution,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) {
log_stream_ << MakeBoxingLoggerCsvRow(status, src_op_name, dst_op_name, src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi,
logical_blob_desc);
dst_parallel_desc, src_parallel_distribution,
dst_parallel_distribution, lbi, logical_blob_desc);
}
} // namespace oneflow
......@@ -29,8 +29,9 @@ class BoxingLogger {
virtual void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,
const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc,
const ParallelDistribution& src_parallel_distribution,
const ParallelDistribution& dst_parallel_distribution, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc) = 0;
};
......@@ -42,8 +43,9 @@ class NullBoxingLogger final : public BoxingLogger {
void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,
const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc,
const ParallelDistribution& src_parallel_distribution,
const ParallelDistribution& dst_parallel_distribution, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc) override{};
};
......@@ -56,8 +58,9 @@ class CsvBoxingLogger final : public BoxingLogger {
void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name,
const std::string& dst_op_name, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc,
const ParallelDistribution& src_parallel_distribution,
const ParallelDistribution& dst_parallel_distribution, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc) override;
private:
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_
#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h"
namespace oneflow {
class HierarchicalSubTskGphBuilder {
public:
OF_DISALLOW_COPY_AND_MOVE(HierarchicalSubTskGphBuilder);
HierarchicalSubTskGphBuilder() = default;
virtual ~HierarchicalSubTskGphBuilder() = default;
virtual Maybe<SubTskGphBuilderStatus> Build(
SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
std::vector<TaskNode*>* sorted_out_tasks,
std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc, const ParallelDistribution& in_parallel_distribution,
const ParallelDistribution& out_parallel_distribution, const Shape& time_shape) const = 0;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
namespace oneflow {
namespace {
void ParallelDimReduce(const ParallelDesc& parallel_desc,
const ParallelDistribution& parallel_distribution,
ParallelDesc* reduced_parallel_desc,
ParallelDistribution* reduced_parallel_distribution) {
const auto& hierarchy = parallel_desc.hierarchy();
DimVector reduced_hierarchy;
reduced_hierarchy.push_back(hierarchy->At(0));
*reduced_parallel_distribution->add_sbp_parallel() = parallel_distribution.sbp_parallel(0);
FOR_RANGE(int64_t, i, 1, hierarchy->NumAxes()) {
if (parallel_distribution.sbp_parallel(i) == parallel_distribution.sbp_parallel(i - 1)) {
reduced_hierarchy.back() *= hierarchy->At(i);
} else {
reduced_hierarchy.push_back(hierarchy->At(i));
*reduced_parallel_distribution->add_sbp_parallel() = parallel_distribution.sbp_parallel(i);
}
}
ParallelConf reduced_parallel_conf = parallel_desc.parallel_conf();
Shape(reduced_hierarchy).ToProto(reduced_parallel_conf.mutable_hierarchy());
*reduced_parallel_desc = ParallelDesc(reduced_parallel_conf);
}
void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
const ParallelDistribution& in_parallel_distribution,
const ParallelDistribution& out_parallel_distribution,
ParallelDesc* reduced_in_parallel_desc,
ParallelDesc* reduced_out_parallel_desc,
ParallelDistribution* reduced_in_parallel_distribution,
ParallelDistribution* reduced_out_parallel_distribution) {
const auto& in_hierarchy = in_parallel_desc.hierarchy();
const auto& out_hierarchy = out_parallel_desc.hierarchy();
CHECK_EQ(in_hierarchy->NumAxes(), out_hierarchy->NumAxes());
DimVector reduced_in_hierarchy;
reduced_in_hierarchy.push_back(in_hierarchy->At(0));
*reduced_in_parallel_distribution->add_sbp_parallel() = in_parallel_distribution.sbp_parallel(0);
DimVector reduced_out_hierarchy;
reduced_out_hierarchy.push_back(out_hierarchy->At(0));
*reduced_out_parallel_distribution->add_sbp_parallel() =
out_parallel_distribution.sbp_parallel(0);
FOR_RANGE(int64_t, i, 1, in_hierarchy->NumAxes()) {
if ((in_parallel_distribution.sbp_parallel(i) == in_parallel_distribution.sbp_parallel(i - 1))
&& (out_parallel_distribution.sbp_parallel(i)
== out_parallel_distribution.sbp_parallel(i - 1))) {
reduced_in_hierarchy.back() *= in_hierarchy->At(i);
reduced_out_hierarchy.back() *= out_hierarchy->At(i);
} else {
reduced_in_hierarchy.push_back(in_hierarchy->At(i));
*reduced_in_parallel_distribution->add_sbp_parallel() =
in_parallel_distribution.sbp_parallel(i);
reduced_out_hierarchy.push_back(out_hierarchy->At(i));
*reduced_out_parallel_distribution->add_sbp_parallel() =
out_parallel_distribution.sbp_parallel(i);
}
}
ParallelConf reduced_in_parallel_conf = in_parallel_desc.parallel_conf();
Shape(reduced_in_hierarchy).ToProto(reduced_in_parallel_conf.mutable_hierarchy());
*reduced_in_parallel_desc = ParallelDesc(reduced_in_parallel_conf);
ParallelConf reduced_out_parallel_conf = out_parallel_desc.parallel_conf();
Shape(reduced_out_hierarchy).ToProto(reduced_out_parallel_conf.mutable_hierarchy());
*reduced_out_parallel_desc = ParallelDesc(reduced_out_parallel_conf);
}
void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
const ParallelDistribution& in_parallel_distribution,
const ParallelDistribution& out_parallel_distribution,
ParallelDesc* reduced_in_parallel_desc,
ParallelDesc* reduced_out_parallel_desc,
ParallelDistribution* reduced_in_parallel_distribution,
ParallelDistribution* reduced_out_parallel_distribution) {
const int64_t in_hierarchy_axes = in_parallel_desc.hierarchy()->NumAxes();
const int64_t out_hierarchy_axes = out_parallel_desc.hierarchy()->NumAxes();
if (in_hierarchy_axes == 1 && out_hierarchy_axes == 1) {
*reduced_in_parallel_desc = in_parallel_desc;
*reduced_out_parallel_desc = out_parallel_desc;
*reduced_in_parallel_distribution = in_parallel_distribution;
*reduced_out_parallel_distribution = out_parallel_distribution;
} else if (in_hierarchy_axes != out_hierarchy_axes) {
ParallelDimReduce(in_parallel_desc, in_parallel_distribution, reduced_in_parallel_desc,
reduced_in_parallel_distribution);
ParallelDimReduce(out_parallel_desc, out_parallel_distribution, reduced_out_parallel_desc,
reduced_out_parallel_distribution);
} else {
CollaborativeParallelDimReduce(in_parallel_desc, out_parallel_desc, in_parallel_distribution,
out_parallel_distribution, reduced_in_parallel_desc,
reduced_out_parallel_desc, reduced_in_parallel_distribution,
reduced_out_parallel_distribution);
}
}
} // namespace
class FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
public:
OF_DISALLOW_COPY_AND_MOVE(FlatSubTskGphBuilder);
FlatSubTskGphBuilder() {
std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
builders.emplace_back(new OneToOneSubTskGphBuilder());
builders.emplace_back(new B21SubTskGphBuilder());
if (!Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
builders.emplace_back(new CollectiveBoxingSubTskGphBuilder());
}
builders.emplace_back(new SliceBoxingSubTskGphBuilder());
builders.emplace_back(new NaiveB2BSubTskGphBuilder());
builders.emplace_back(new NaiveB2PSubTskGphBuilder());
sub_tsk_gph_builder_.reset(new ChainSubTskGphBuilder(builders));
}
~FlatSubTskGphBuilder() = default;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<TaskNode*>& sorted_in_tasks,
std::vector<TaskNode*>* sorted_out_tasks,
std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const ParallelDistribution& in_parallel_distribution,
const ParallelDistribution& out_parallel_distribution,
const Shape& time_shape) const override {
return JUST(sub_tsk_gph_builder_->Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
out_parallel_desc, lbi, logical_blob_desc, in_parallel_distribution.sbp_parallel(0),
out_parallel_distribution.sbp_parallel(0), time_shape));
}
private:
std::unique_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
};
struct DispatchHierarchicalSubTskGphBuilder::Impl {
Impl();
std::unique_ptr<FlatSubTskGphBuilder> flat_sub_tsk_gph_builder_;
};
DispatchHierarchicalSubTskGphBuilder::Impl::Impl() {
flat_sub_tsk_gph_builder_.reset(new FlatSubTskGphBuilder());
}
DispatchHierarchicalSubTskGphBuilder::DispatchHierarchicalSubTskGphBuilder() {
impl_.reset(new Impl());
}
DispatchHierarchicalSubTskGphBuilder::~DispatchHierarchicalSubTskGphBuilder() = default;
Maybe<SubTskGphBuilderStatus> DispatchHierarchicalSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks,
std::vector<TaskNode*>* sorted_out_tasks,
std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc, const ParallelDistribution& in_parallel_distribution,
const ParallelDistribution& out_parallel_distribution, const Shape& time_shape) const {
ParallelDesc reduced_in_parallel_desc = in_parallel_desc;
ParallelDesc reduced_out_parallel_desc = out_parallel_desc;
ParallelDistribution reduced_in_parallel_distribution;
ParallelDistribution reduced_out_parallel_distribution;
InOutParallelDimReduce(in_parallel_desc, out_parallel_desc, in_parallel_distribution,
out_parallel_distribution, &reduced_in_parallel_desc,
&reduced_out_parallel_desc, &reduced_in_parallel_distribution,
&reduced_out_parallel_distribution);
const auto& reduced_in_parallel_hierarchy = reduced_in_parallel_desc.hierarchy();
const auto& reduced_out_parallel_hierarchy = reduced_out_parallel_desc.hierarchy();
if (reduced_in_parallel_hierarchy->NumAxes() == 1
&& reduced_out_parallel_hierarchy->NumAxes() == 1) {
return impl_->flat_sub_tsk_gph_builder_->Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_parallel_distribution,
reduced_out_parallel_distribution, time_shape);
} else {
return Error::BoxingNotSupportedError();
}
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_
#define ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder.h"
namespace oneflow {
class DispatchHierarchicalSubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
public:
OF_DISALLOW_COPY_AND_MOVE(DispatchHierarchicalSubTskGphBuilder);
DispatchHierarchicalSubTskGphBuilder();
~DispatchHierarchicalSubTskGphBuilder() override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<TaskNode*>& sorted_in_tasks,
std::vector<TaskNode*>* sorted_out_tasks,
std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const ParallelDistribution& in_parallel_distribution,
const ParallelDistribution& out_parallel_distribution,
const Shape& time_shape) const override;
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_HIERARCHICAL_SUB_TASK_GRAPH_BUILDER_IMPL_H_
......@@ -21,20 +21,12 @@ limitations under the License.
#include "oneflow/core/operator/variable_op.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/chain_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing_identity_task_node.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
namespace oneflow {
......@@ -255,16 +247,7 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
logical_gph_ = std::move(logical_gph);
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
builders.emplace_back(new OneToOneSubTskGphBuilder());
builders.emplace_back(new B21SubTskGphBuilder());
if (!Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
builders.emplace_back(new CollectiveBoxingSubTskGphBuilder());
}
builders.emplace_back(new SliceBoxingSubTskGphBuilder());
builders.emplace_back(new NaiveB2BSubTskGphBuilder());
builders.emplace_back(new NaiveB2PSubTskGphBuilder());
sub_tsk_gph_builder_.reset(new ChainSubTskGphBuilder(builders));
hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;
HashMap<CompTaskNode*, HashMap<int64_t, std::vector<TaskNode*>>> buf_task;
auto MutBufTask = [&](CompTaskNode* task_node, int64_t machine_id, int32_t mem_zone_id) {
......@@ -306,6 +289,8 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
}
TaskGraph::~TaskGraph() = default;
Maybe<void> TaskGraph::ConnectDstSubsetTickEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes) {
std::function<Maybe<CompTaskNode*>(int64_t mchn_id, int64_t thrd_id)> TaskNode4MachineId7ThrdId;
......@@ -523,27 +508,30 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
std::vector<TaskNode*> out_nodes;
out_nodes.reserve(sorted_dst_comp_tasks.size());
std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks;
const SbpParallel& src_sbp_parallel =
Global<OpGraph>::Get()->GetSbpParallel(src_logical->SoleOp()->op_name(), lbi);
const SbpParallel& dst_sbp_parallel =
Global<OpGraph>::Get()->GetSbpParallel(dst_logical->SoleOp()->op_name(), lbi);
const ParallelDistribution& src_parallel_distribution =
Global<OpGraph>::Get()->GetParallelDistribution(src_logical->SoleOp()->op_name(), lbi);
const ParallelDistribution& dst_parallel_distribution =
Global<OpGraph>::Get()->GetParallelDistribution(dst_logical->SoleOp()->op_name(), lbi);
const std::shared_ptr<const ParallelDesc>& src_parallel_desc = src_logical->parallel_desc();
const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc();
const BlobDesc& blob_desc = Global<OpGraph>::Get()->GetLogicalBlobDesc(lbi);
auto status = CHECK_JUST(sub_tsk_gph_builder_->Build(
auto status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(
sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
*src_parallel_desc, *dst_parallel_desc, lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel,
*src_logical->out_blob_time_shape()));
*src_parallel_desc, *dst_parallel_desc, lbi, blob_desc, src_parallel_distribution,
dst_parallel_distribution, *src_logical->out_blob_time_shape()));
boxing_logger_->Log(*status, src_logical->SoleOp()->op_name(), dst_logical->SoleOp()->op_name(),
*src_parallel_desc, *dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
lbi, blob_desc);
*src_parallel_desc, *dst_parallel_desc, src_parallel_distribution,
dst_parallel_distribution, lbi, blob_desc);
sub_tsk_gph_builder_ctx_->ConnectAll121(out_nodes, sorted_dst_comp_tasks);
if (!sorted_ctrl_tasks.empty()) {
CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size());
FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) {
for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(i)) {
Connect<TaskNode>(ctrl_node, NewEdge(), sorted_dst_comp_tasks.at(i));
ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i));
std::string regst_desc_name;
ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i), &regst_desc_name);
TaskEdge* edge = NewEdge();
Connect<TaskNode>(ctrl_node, edge, sorted_dst_comp_tasks.at(i));
ctrl_node->BindEdgeWithProducedRegst(edge, regst_desc_name);
}
}
}
......
......@@ -26,14 +26,14 @@ limitations under the License.
namespace oneflow {
class SubTskGphBuilder;
class SubTskGphBuilderCtx;
class HierarchicalSubTskGphBuilder;
class TaskGraph final : public Graph<TaskNode, TaskEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
TaskGraph() = delete;
~TaskGraph() override = default;
~TaskGraph() override;
explicit TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph);
......@@ -98,8 +98,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
std::unique_ptr<const LogicalGraph> logical_gph_;
std::vector<TaskNode*> ordered_task_nodes_;
std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
std::shared_ptr<SubTskGphBuilderCtx> sub_tsk_gph_builder_ctx_;
std::unique_ptr<HierarchicalSubTskGphBuilder> hierarchical_sub_tsk_gph_builder_;
std::unique_ptr<SubTskGphBuilderCtx> sub_tsk_gph_builder_ctx_;
std::unique_ptr<BoxingLogger> boxing_logger_;
};
......
......@@ -57,8 +57,8 @@ const LogicalEdge* GetConnectedEdge(const LogicalNode* src_node, const LogicalNo
return connect_edge;
}
static bool IsConnectedLbisAllSameSbpParallel(const LogicalNode* src_node,
const LogicalNode* dst_node) {
static bool IsConnectedLbisAllSameParallelDistribution(const LogicalNode* src_node,
const LogicalNode* dst_node) {
if (src_node->parallel_desc()->parallel_num() != dst_node->parallel_desc()->parallel_num()) {
return false;
}
......@@ -69,9 +69,11 @@ static bool IsConnectedLbisAllSameSbpParallel(const LogicalNode* src_node,
const std::string& dst_op_name = dst_node->SoleOp()->op_name();
HashSet<bool> predicators;
for (const LogicalBlobId& lbi : connect_edge->lbis()) {
const auto& src_sbp = Global<OpGraph>::Get()->GetSbpParallel(src_op_name, lbi);
const auto& dst_sbp = Global<OpGraph>::Get()->GetSbpParallel(dst_op_name, lbi);
predicators.insert(src_sbp == dst_sbp);
const ParallelDistribution& src_parallel_distribution =
Global<OpGraph>::Get()->GetParallelDistribution(src_op_name, lbi);
const ParallelDistribution& dst_parallel_distribution =
Global<OpGraph>::Get()->GetParallelDistribution(dst_op_name, lbi);
predicators.insert(src_parallel_distribution == dst_parallel_distribution);
}
CHECK_EQ(predicators.size(), 1);
return *predicators.begin();
......@@ -223,7 +225,8 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic
return &TaskGraph::BldSubTskGphByOneToOne;
}
if (src_pd->parallel_num() == dst_pd->parallel_num()
&& IsConnectedLbisAllSameSbpParallel(src_node, dst_node)) {
&& *src_pd->hierarchy() == *dst_pd->hierarchy()
&& IsConnectedLbisAllSameParallelDistribution(src_node, dst_node)) {
return &TaskGraph::BldSubTskGphByOneToOne;
}
return &TaskGraph::BldSubTskGphByBoxing;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册