未验证 提交 e0370ccd 编写于 作者: qq_22305325's avatar qq_22305325 提交者: GitHub

Dev boxing log (#3466)

* add SubTskGphBuilderUtil::BuildBoxingInfo() interface

* add BuildBoxingInfo to sub_task_graph_builder

* save boxing log

* add not nullptr check

* add TRY

* remove blank line

* change BuildBoxingInfo to BuildBoxingLogInfo

* fix code style

* rename boxing_logging_ to boxing_logging_lines_

* fix c++ code format

* rename boxing_info to boxing_log_line

* rename boxing_info to boxing_log_line

* fix code format

* abstract SubTskGphBuilderStatus struct

* rename boxing_log_line to boxingbuilderstatus

* delete useless code

* fix code format

* remove useless include file

* complete GetBlobInfo4LogicalBlobDesc

* change boxing_log_lines to boxing_log_list_(std::list)

* Optimize log saving function

* remove useless include file

* add column field macro

* add CHECK_NOT_NULL

* Overload SetLogStream

* change inappropriate variable name

* remove useless include file

* add CHECK_JUST & TRY

* fix c++ format

* add new line at end of file

* optimization boxing log

* remove useless include

* fix c++ format

* optimize SubTskGphBuilderStatus Getters

* change macro name

* fix c++ format

* optimize boxing log

* fix c++ format

* rename SerializeParallelDesc to ParallelDescToString

* optimize boxing log

* fix small bug

* optimize boxing log

* fix c++ format

* optimize boxing log

* fix c++ format

* optimize boxing log

* fix c++ format

* optimize boxing log

* optimize boxing log

* fix c++ format

* optimize boxing log

* remove useless include file

* optimize boxing log
Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
上级 1d35fa1e
......@@ -18,14 +18,12 @@ limitations under the License.
namespace oneflow {
Maybe<void> B21SubTskGphBuilder::Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const {
Maybe<SubTskGphBuilderStatus> B21SubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const {
if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel())
&& dst_parallel_desc.parallel_num() == 1) {
CompTaskNode* dst_node = sorted_dst_comp_tasks.front();
......@@ -35,7 +33,10 @@ Maybe<void> B21SubTskGphBuilder::Build(SubTskGphBuilderCtx* ctx,
TaskNode* proxy = ctx->GetProxyNode(nearest_src_node, nearest_src_node->MemZoneId121(),
dst_node->machine_id(), dst_node->MemZoneId121());
Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node);
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(),
sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
lbi, logical_blob_desc, "B21SubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......
......@@ -26,13 +26,14 @@ class B21SubTskGphBuilder final : public SubTskGphBuilder {
B21SubTskGphBuilder() = default;
~B21SubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
};
} // namespace oneflow
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/boxing/boxing_logger.h"
namespace oneflow {
namespace {
// Header row of the boxing CSV log. The CsvBoxingLogger constructor writes it once,
// before any data row. The column order matches the field order emitted by
// SubTskGphBuilderStatusToCsvLine(); keep the two in sync when adding a column.
#define OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD \
"src_op_name,dst_op_name,src_parallel_conf,dst_parallel_conf," \
"src_sbp_conf,dst_sbp_conf,lbi,dtype,shape,builder,comment\n"
// Renders an SbpParallel as a short tag for the boxing log:
//   broadcast   -> "B"
//   partial sum -> "P"
//   split       -> "S(<axis>)"
// Any other configuration hits UNIMPLEMENTED().
std::string SbpParallelToString(const SbpParallel& sbp_parallel) {
  if (sbp_parallel.has_broadcast_parallel()) { return "B"; }
  if (sbp_parallel.has_partial_sum_parallel()) { return "P"; }
  if (sbp_parallel.has_split_parallel()) {
    const auto split_axis = sbp_parallel.split_parallel().axis();
    return "S(" + std::to_string(split_axis) + ")";
  }
  UNIMPLEMENTED();
  return std::string();
}
// Renders a ParallelDesc as space-separated "<machine_id>:<device>:<min_dev>-<max_dev>"
// entries, one per machine in sorted_machine_ids() order. <device> is "CPU", "GPU" or
// "UNKNOWN_DEVICE".
//
// Only the first and last entries of sorted_dev_phy_ids() are printed, so a
// non-contiguous device set is shown as the range that encloses it.
std::string ParallelDescToString(const ParallelDesc& parallel_desc) {
  std::string device_type;
  if (parallel_desc.device_type() == DeviceType::kCPU) {
    device_type = "CPU";
  } else if (parallel_desc.device_type() == DeviceType::kGPU) {
    device_type = "GPU";
  } else {
    device_type = "UNKNOWN_DEVICE";
  }
  // Bind by reference: the id lists are only read, so copying them is wasted work.
  const auto& sorted_machine_ids = parallel_desc.sorted_machine_ids();
  std::string serialized_parallel_desc;
  bool is_first_machine = true;
  for (const int64_t machine_id : sorted_machine_ids) {
    // Separator goes between entries, never after the last one.
    if (!is_first_machine) { serialized_parallel_desc += " "; }
    is_first_machine = false;
    // Look the device list up once per machine instead of once per bound.
    const auto& dev_phy_ids = parallel_desc.sorted_dev_phy_ids(machine_id);
    const int64_t min_id = dev_phy_ids.front();
    const int64_t max_id = dev_phy_ids.back();
    serialized_parallel_desc += std::to_string(machine_id) + ":" + device_type + ":";
    serialized_parallel_desc += std::to_string(min_id) + "-" + std::to_string(max_id);
  }
  return serialized_parallel_desc;
}
// Renders a Shape as "[d0 d1 ... dn]". Dimensions are separated by spaces rather
// than commas so the result can sit inside one CSV field. A shape with no
// dimensions renders as "[]".
std::string ShapeToString(const Shape& shape) {
  // Bind by reference: the dimension vector is only read here.
  const auto& dim_vec = shape.dim_vec();
  std::stringstream shape_ss;
  shape_ss << "[";
  // size_t index avoids the signed/unsigned comparison against dim_vec.size().
  for (size_t i = 0; i < dim_vec.size(); ++i) {
    if (i != 0) { shape_ss << " "; }
    shape_ss << dim_vec.at(i);
  }
  shape_ss << "]";
  return shape_ss.str();
}
// Serializes one SubTskGphBuilderStatus into a single newline-terminated CSV row.
// Field order: src op, dst op, src parallel desc, dst parallel desc, src sbp,
// dst sbp, logical blob name, data type, shape, builder name, comment.
// An empty comment is written as "-" so the last column is never blank.
// Fields are written verbatim; no CSV quoting or escaping is applied.
std::string SubTskGphBuilderStatusToCsvLine(const SubTskGphBuilderStatus& status) {
  const auto& blob_desc = status.logical_blob_desc();
  const std::string comment_field =
      status.comment().empty() ? std::string("-") : std::string(status.comment());
  std::stringstream csv_line;
  csv_line << status.src_op_name() << ","
           << status.dst_op_name() << ","
           << ParallelDescToString(status.src_parallel_desc()) << ","
           << ParallelDescToString(status.dst_parallel_desc()) << ","
           << SbpParallelToString(status.src_sbp_parallel()) << ","
           << SbpParallelToString(status.dst_sbp_parallel()) << ","
           << GenLogicalBlobName(status.lbi()) << ","
           << DataType_Name(blob_desc.data_type()) << ","
           << ShapeToString(blob_desc.shape()) << ","
           << status.builder_name() << ","
           << comment_field << "\n";
  return csv_line.str();
}
} // namespace
// Creates the log stream for `path` and immediately writes the CSV header row.
CsvBoxingLogger::CsvBoxingLogger(std::string path)
    : log_stream_(TeePersistentLogStream::Create(path)) {
  log_stream_ << OF_BOXING_LOGGER_CSV_COLNUM_NAME_FIELD;
}
// Flushes the log stream when the logger is destroyed.
CsvBoxingLogger::~CsvBoxingLogger() {
  log_stream_->Flush();
}
// Appends one CSV row describing `status` to the log stream.
void CsvBoxingLogger::Log(const SubTskGphBuilderStatus& status) {
  const std::string csv_row = SubTskGphBuilderStatusToCsvLine(status);
  log_stream_ << csv_row;
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_
#define ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h"
namespace oneflow {
// Abstract sink for boxing build results. Implementations decide what to do with
// each SubTskGphBuilderStatus passed to Log().
class BoxingLogger {
 public:
  BoxingLogger() = default;
  virtual ~BoxingLogger() = default;
  OF_DISALLOW_COPY_AND_MOVE(BoxingLogger);

  // Records a single boxing build result.
  virtual void Log(const SubTskGphBuilderStatus& status) = 0;
};
// BoxingLogger that discards every record; its Log() does nothing.
class NullBoxingLogger final : public BoxingLogger {
 public:
  OF_DISALLOW_COPY_AND_MOVE(NullBoxingLogger);
  NullBoxingLogger() = default;
  ~NullBoxingLogger() override = default;

  // Intentionally a no-op. The parameter is left unnamed because it is unused, and
  // the stray ';' that followed the body has been dropped.
  void Log(const SubTskGphBuilderStatus& /*status*/) override {}
};
// BoxingLogger that writes each record as one CSV row to a TeePersistentLogStream.
class CsvBoxingLogger final : public BoxingLogger {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CsvBoxingLogger);
  CsvBoxingLogger() = delete;
  // `path` is handed to TeePersistentLogStream::Create(). Marked explicit so a
  // std::string is never implicitly converted into a logger.
  explicit CsvBoxingLogger(std::string path);
  ~CsvBoxingLogger() override;

  // Appends one CSV row describing `status`.
  void Log(const SubTskGphBuilderStatus& status) override;

 private:
  std::unique_ptr<TeePersistentLogStream> log_stream_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_LOGGER_H_
......@@ -18,20 +18,21 @@ limitations under the License.
namespace oneflow {
Maybe<void> ChainSubTskGphBuilder::Build(
Maybe<SubTskGphBuilderStatus> ChainSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const {
for (const auto& builder : builders_) {
Maybe<void> status = TRY(builder->Build(ctx, sorted_src_comp_tasks, sorted_dst_comp_tasks,
src_parallel_desc, dst_parallel_desc, lbi,
logical_blob_desc, src_sbp_parallel, dst_sbp_parallel));
if (!status.IsOk() && SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*status.error())) {
Maybe<SubTskGphBuilderStatus> boxing_builder_status = TRY(builder->Build(
ctx, sorted_src_comp_tasks, sorted_dst_comp_tasks, src_parallel_desc, dst_parallel_desc,
lbi, logical_blob_desc, src_sbp_parallel, dst_sbp_parallel));
if (!boxing_builder_status.IsOk()
&& SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*boxing_builder_status.error())) {
continue;
} else {
return status;
return boxing_builder_status;
}
}
return Error::BoxingNotSupported();
......
......@@ -27,13 +27,14 @@ class ChainSubTskGphBuilder final : public SubTskGphBuilder {
: builders_(std::move(builders)) {}
~ChainSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
private:
std::vector<std::shared_ptr<SubTskGphBuilder>> builders_;
......
......@@ -87,13 +87,14 @@ class NcclCollectiveBoxingAllReduceSubTskGphBuilder final : public SubTskGphBuil
NcclCollectiveBoxingAllReduceSubTskGphBuilder() = default;
~NcclCollectiveBoxingAllReduceSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
if (dst_parallel_desc.Equals(src_parallel_desc)
&& !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)
&& dst_parallel_desc.device_type() == DeviceType::kGPU
......@@ -109,7 +110,10 @@ class NcclCollectiveBoxingAllReduceSubTskGphBuilder final : public SubTskGphBuil
Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node);
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(
sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
"NcclCollectiveBoxingAllReduceSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......@@ -122,13 +126,14 @@ class NcclCollectiveBoxingReduceScatterSubTskGphBuilder final : public SubTskGph
NcclCollectiveBoxingReduceScatterSubTskGphBuilder() = default;
~NcclCollectiveBoxingReduceScatterSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
if (dst_parallel_desc.Equals(src_parallel_desc)
&& !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)
&& dst_parallel_desc.device_type() == DeviceType::kGPU
......@@ -147,7 +152,10 @@ class NcclCollectiveBoxingReduceScatterSubTskGphBuilder final : public SubTskGph
Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node);
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(
sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
"NcclCollectiveBoxingReduceScatterSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......@@ -160,13 +168,14 @@ class NcclCollectiveBoxingAllGatherSubTskGphBuilder final : public SubTskGphBuil
NcclCollectiveBoxingAllGatherSubTskGphBuilder() = default;
~NcclCollectiveBoxingAllGatherSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
if (dst_parallel_desc.EqualsIgnoringDeviceType(src_parallel_desc)
&& !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)
&& SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(src_parallel_desc)
......@@ -187,7 +196,10 @@ class NcclCollectiveBoxingAllGatherSubTskGphBuilder final : public SubTskGphBuil
Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node);
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
// Report this builder's own name; the status previously carried the ReduceScatter
// builder's name, which mislabeled all-gather edges in the boxing log.
return TRY(BuildSubTskGphBuilderStatus(
    sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
    dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
    "NcclCollectiveBoxingAllGatherSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......@@ -200,13 +212,14 @@ class NcclCollectiveBoxingReduceSubTskGphBuilder final : public SubTskGphBuilder
NcclCollectiveBoxingReduceSubTskGphBuilder() = default;
~NcclCollectiveBoxingReduceSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
if (src_parallel_desc.parallel_num() > 1 && dst_parallel_desc.parallel_num() == 1
&& src_parallel_desc.device_type() == DeviceType::kGPU
&& dst_parallel_desc.device_type() == DeviceType::kGPU
......@@ -230,7 +243,10 @@ class NcclCollectiveBoxingReduceSubTskGphBuilder final : public SubTskGphBuilder
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node);
}
}
return Maybe<void>::Ok();
// Report this builder's own name; the status previously carried the ReduceScatter
// builder's name, which mislabeled reduce edges in the boxing log.
return TRY(BuildSubTskGphBuilderStatus(
    sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
    dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
    "NcclCollectiveBoxingReduceSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......@@ -243,13 +259,14 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder() = default;
~CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
if (src_parallel_desc.parallel_num() == 1 && dst_parallel_desc.parallel_num() > 1
&& src_parallel_desc.device_type() == DeviceType::kCPU
&& dst_parallel_desc.device_type() == DeviceType::kGPU
......@@ -287,7 +304,10 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
Connect<TaskNode>(slice_node_proxy, ctx->task_graph()->NewEdge(), collective_node);
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
// Report this builder's own name; the status previously carried the ReduceScatter
// builder's name, which mislabeled scatter-then-all-gather edges in the boxing log.
return TRY(BuildSubTskGphBuilderStatus(
    sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
    dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
    "CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......@@ -300,13 +320,14 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil
NcclCollectiveBoxingBroadcastSubTskGphBuilder() = default;
~NcclCollectiveBoxingBroadcastSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override {
if (src_parallel_desc.parallel_num() == 1 && dst_parallel_desc.parallel_num() > 1
&& (src_parallel_desc.device_type() == DeviceType::kGPU
|| (src_parallel_desc.device_type() == DeviceType::kCPU
......@@ -346,7 +367,10 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil
}
Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
// Report this builder's own name; the status previously carried the ReduceScatter
// builder's name, which mislabeled broadcast edges in the boxing log.
return TRY(BuildSubTskGphBuilderStatus(
    sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
    dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
    "NcclCollectiveBoxingBroadcastSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......@@ -365,7 +389,7 @@ CollectiveBoxingSubTskGphBuilder::CollectiveBoxingSubTskGphBuilder() {
chain_builder_.reset(new ChainSubTskGphBuilder(builders));
}
Maybe<void> CollectiveBoxingSubTskGphBuilder::Build(
Maybe<SubTskGphBuilderStatus> CollectiveBoxingSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
......
......@@ -26,13 +26,14 @@ class CollectiveBoxingSubTskGphBuilder final : public SubTskGphBuilder {
CollectiveBoxingSubTskGphBuilder();
~CollectiveBoxingSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
private:
std::unique_ptr<SubTskGphBuilder> chain_builder_;
......
......@@ -18,7 +18,7 @@ limitations under the License.
namespace oneflow {
Maybe<void> NaiveB2BSubTskGphBuilder::Build(
Maybe<SubTskGphBuilderStatus> NaiveB2BSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
......@@ -35,7 +35,10 @@ Maybe<void> NaiveB2BSubTskGphBuilder::Build(
dst_node->machine_id(), dst_node->MemZoneId121());
Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(),
sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
lbi, logical_blob_desc, "NaiveB2BSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......
......@@ -26,13 +26,14 @@ class NaiveB2BSubTskGphBuilder final : public SubTskGphBuilder {
NaiveB2BSubTskGphBuilder() = default;
~NaiveB2BSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
};
} // namespace oneflow
......
......@@ -18,7 +18,7 @@ limitations under the License.
namespace oneflow {
Maybe<void> OneToOneSubTskGphBuilder::Build(
Maybe<SubTskGphBuilderStatus> OneToOneSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
......@@ -35,7 +35,10 @@ Maybe<void> OneToOneSubTskGphBuilder::Build(
dst_node->machine_id(), dst_node->MemZoneId121());
Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(),
sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
lbi, logical_blob_desc, "OneToOneSubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupported();
}
......
......@@ -26,13 +26,14 @@ class OneToOneSubTskGphBuilder final : public SubTskGphBuilder {
OneToOneSubTskGphBuilder() = default;
~OneToOneSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
};
} // namespace oneflow
......
......@@ -52,7 +52,7 @@ bool IsSameDevice(const ParallelDesc& in_pd, const ParallelDesc& out_pd,
} // namespace
Maybe<void> SliceBoxingSubTskGphBuilder::Build(
Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
......@@ -439,19 +439,24 @@ Maybe<void> SliceBoxingSubTskGphBuilder::Build(
std::vector<TaskNode*> in_nodes;
in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end());
std::vector<TaskNode*> out_nodes;
std::string comment;
if (SubTskGphBuilderUtil::IsBoxingS2B(src_sbp_parallel, dst_sbp_parallel)) {
BuildSubTaskGphS2B(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
logical_blob_desc, in_nodes, &out_nodes);
comment = "BuildSubTaskGphS2B";
} else if (SubTskGphBuilderUtil::IsBoxingS2S(src_sbp_parallel, dst_sbp_parallel)) {
BuildSubTaskGphS2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
logical_blob_desc, in_nodes, &out_nodes);
comment = "BuildSubTaskGphS2S";
} else if (SubTskGphBuilderUtil::IsBoxingP2S(src_sbp_parallel, dst_sbp_parallel)) {
BuildSubTaskGphP2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
logical_blob_desc, in_nodes, &out_nodes);
comment = "BuildSubTaskGphP2S";
} else if (SubTskGphBuilderUtil::IsBoxingP2B(src_sbp_parallel, dst_sbp_parallel)) {
if (logical_blob_desc.shape().elem_cnt() < dst_parallel_desc.parallel_num()) {
BuildSubTaskGphP2B(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
logical_blob_desc, in_nodes, &out_nodes);
comment = "BuildSubTaskGphP2B";
} else {
BlobDesc flat_blob_desc(logical_blob_desc.data_type());
flat_blob_desc.mut_shape() = Shape({logical_blob_desc.shape().elem_cnt()});
......@@ -462,20 +467,26 @@ Maybe<void> SliceBoxingSubTskGphBuilder::Build(
flat_blob_desc, in_nodes, &middle_nodes);
BuildSubTaskGphS2B(dst_parallel_desc, dst_parallel_desc, middle_sbp, dst_sbp_parallel,
flat_blob_desc, middle_nodes, &out_nodes);
comment = "BuildSubTaskGphP2S->BuildSubTaskGphS2B";
for (TaskNode* out_node : out_nodes) {
auto* slice_boxing_node = dynamic_cast<SliceBoxingTaskNode*>(out_node);
CHECK_NOTNULL(slice_boxing_node);
slice_boxing_node->SetOutShape(logical_blob_desc.shape());
}
}
} else if (SubTskGphBuilderUtil::IsBoxingB2S(src_sbp_parallel, dst_sbp_parallel)) {
BuildSubTaskGphB2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
logical_blob_desc, in_nodes, &out_nodes);
comment = "BuildSubTaskGphB2S";
} else {
UNIMPLEMENTED();
}
ctx->ConnectAll121(out_nodes, sorted_dst_comp_tasks);
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(
sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
"SliceBoxingSubTskGphBuilder", comment));
}
} // namespace oneflow
......@@ -26,13 +26,14 @@ class SliceBoxingSubTskGphBuilder final : public SubTskGphBuilder {
SliceBoxingSubTskGphBuilder() = default;
~SliceBoxingSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
};
} // namespace oneflow
......
......@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/common/util.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_context.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h"
namespace oneflow {
......@@ -27,13 +28,12 @@ class SubTskGphBuilder {
SubTskGphBuilder() = default;
virtual ~SubTskGphBuilder() = default;
virtual Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const = 0;
virtual Maybe<SubTskGphBuilderStatus> Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel) const = 0;
};
} // namespace oneflow
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h"
#include "oneflow/core/graph/logical_node.h"
namespace oneflow {
// Assembles the SubTskGphBuilderStatus record for one boxing build.
//
// The source and destination operator names are taken from the first operator
// (`op_vec().at(0)`) of the logical node behind `src_node` and `dst_node`; every other
// argument is forwarded to the SubTskGphBuilderStatus constructor as-is.
// Both node pointers are dereferenced without a null check.
Maybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus(
    const CompTaskNode* src_node, const CompTaskNode* dst_node,
    const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
    const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel,
    const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const std::string& builder_name,
    const std::string& comment) {
  // Name of the first operator held by the logical node of `node`.
  const auto first_op_name = [](const CompTaskNode* node) -> std::string {
    return node->logical_node()->op_vec().at(0)->op_name();
  };
  // The source side is resolved completely before the destination side is touched.
  const std::string src_name = first_op_name(src_node);
  const std::string dst_name = first_op_name(dst_node);

  SubTskGphBuilderStatus result(src_name, dst_name, src_parallel_desc, dst_parallel_desc,
                                src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
                                builder_name, comment);
  return result;
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_
#define ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class SubTskGphBuilderStatus;
// Builds the SubTskGphBuilderStatus describing one sub-task-graph build.
//
// src_node / dst_node: compute task nodes whose logical node supplies the source and
//   destination operator names (the first entry of the node's op_vec is used). Both
//   pointers are dereferenced unconditionally by the implementation.
// src_parallel_desc / dst_parallel_desc, src_sbp_parallel / dst_sbp_parallel, lbi,
// logical_blob_desc: stored in the returned status unchanged.
// builder_name: name of the builder that handled the boxing.
// comment: free-form text stored alongside the builder name.
Maybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus(
const CompTaskNode* src_node, const CompTaskNode* dst_node,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const std::string& builder_name,
const std::string& comment);
// Record of a single sub-task-graph build: which operators were connected, how both
// sides are placed and partitioned, which blob was transferred, and which builder
// (plus an optional comment) produced the result.
//
// All arguments are copied into the object at construction time and are exposed only
// through const accessors.
class SubTskGphBuilderStatus final {
 public:
  // src_op_name / dst_op_name: names of the producing and consuming operators.
  // src_parallel_desc / dst_parallel_desc: placement of the source / destination side.
  // src_sbp_parallel / dst_sbp_parallel: SBP signature of the blob on each side.
  // lbi: id of the logical blob being transferred.
  // logical_blob_desc: descriptor of that logical blob.
  // builder_name: name of the builder that handled the build.
  // comment: free-form text supplied by the builder.
  SubTskGphBuilderStatus(const std::string& src_op_name, const std::string& dst_op_name,
                         const ParallelDesc& src_parallel_desc,
                         const ParallelDesc& dst_parallel_desc,
                         const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel,
                         const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
                         const std::string& builder_name, const std::string& comment)
      : src_op_name_(src_op_name),
        dst_op_name_(dst_op_name),
        src_parallel_desc_(src_parallel_desc),
        dst_parallel_desc_(dst_parallel_desc),
        // The parameter used to carry a trailing underscore and thereby hid the member
        // of the same name; it now follows the naming of the other parameters.
        src_sbp_parallel_(src_sbp_parallel),
        dst_sbp_parallel_(dst_sbp_parallel),
        lbi_(lbi),
        logical_blob_desc_(logical_blob_desc),
        builder_name_(builder_name),
        comment_(comment) {}
  ~SubTskGphBuilderStatus() = default;

  // Getters: each returns a const reference to the value stored at construction.
  const std::string& src_op_name() const { return src_op_name_; }
  const std::string& dst_op_name() const { return dst_op_name_; }
  const ParallelDesc& src_parallel_desc() const { return src_parallel_desc_; }
  const ParallelDesc& dst_parallel_desc() const { return dst_parallel_desc_; }
  const SbpParallel& src_sbp_parallel() const { return src_sbp_parallel_; }
  const SbpParallel& dst_sbp_parallel() const { return dst_sbp_parallel_; }
  const LogicalBlobId& lbi() const { return lbi_; }
  const BlobDesc& logical_blob_desc() const { return logical_blob_desc_; }
  const std::string& builder_name() const { return builder_name_; }
  const std::string& comment() const { return comment_; }

 private:
  std::string src_op_name_;
  std::string dst_op_name_;
  ParallelDesc src_parallel_desc_;
  ParallelDesc dst_parallel_desc_;
  SbpParallel src_sbp_parallel_;
  SbpParallel dst_sbp_parallel_;
  LogicalBlobId lbi_;
  BlobDesc logical_blob_desc_;
  std::string builder_name_;
  std::string comment_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_SUB_TASK_GRAPH_BUILDER_STATUS_UTIL_H_
......@@ -46,7 +46,6 @@ struct SubTskGphBuilderUtil {
static bool BlobHasDynamicShape(const BlobDesc& blob_desc);
static bool IsErrorBoxingNotSupported(const ErrorProto& error);
static int64_t GetDistance(const TaskNode* src, const TaskNode* dst);
template<typename NodeType>
static int64_t FindNearestNodeIndex(const std::vector<NodeType*> from_nodes,
const NodeType* to_node) {
......
......@@ -19,7 +19,7 @@ limitations under the License.
namespace oneflow {
Maybe<void> ToInterfaceSubTskGphBuilder::Build(
Maybe<SubTskGphBuilderStatus> ToInterfaceSubTskGphBuilder::Build(
SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
......@@ -47,7 +47,10 @@ Maybe<void> ToInterfaceSubTskGphBuilder::Build(
Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node);
}
}
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(
sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
"ToInterfaceSubTskGphBuilder", "BuildSubTaskGphB2B"));
} else if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel())
&& (dst_parallel_desc.parallel_num() > 1 || dst_sbp_parallel.has_split_parallel())) {
const TensorSliceView in_slice =
......@@ -78,7 +81,10 @@ Maybe<void> ToInterfaceSubTskGphBuilder::Build(
Global<IDMgr>::Get()->CpuMemZoneId());
Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node);
}
return Maybe<void>::Ok();
return TRY(BuildSubTskGphBuilderStatus(
sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc,
dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc,
"ToInterfaceSubTskGphBuilder", "BuildSubTaskGphB2S"));
} else {
return Error::BoxingNotSupported();
}
......
......@@ -26,13 +26,14 @@ class ToInterfaceSubTskGphBuilder final : public SubTskGphBuilder {
ToInterfaceSubTskGphBuilder() = default;
~ToInterfaceSubTskGphBuilder() override = default;
Maybe<void> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
};
} // namespace oneflow
......
......@@ -142,11 +142,21 @@ bool IsInplaceAllowed(
return true;
}
std::unique_ptr<BoxingLogger> CreateBoxingLogger() {
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
return std::unique_ptr<BoxingLogger>(
new CsvBoxingLogger(StrCat("boxing/log/", GlobalJobDesc().job_id()) + ".csv"));
} else {
return std::unique_ptr<BoxingLogger>(new NullBoxingLogger());
}
}
} // namespace
TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
logical_gph_ = std::move(logical_gph);
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
std::vector<std::shared_ptr<SubTskGphBuilder>> builders;
builders.emplace_back(new ToInterfaceSubTskGphBuilder());
builders.emplace_back(new OneToOneSubTskGphBuilder());
......@@ -464,10 +474,10 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
const std::shared_ptr<const ParallelDesc>& src_parallel_desc = src_logical->parallel_desc();
const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc();
const BlobDesc& blob_desc = Global<OpGraph>::Get()->GetLogicalBlobDesc(lbi);
Maybe<void> status = TRY(sub_tsk_gph_builder_->Build(
auto status = CHECK_JUST(sub_tsk_gph_builder_->Build(
sub_tsk_gph_builder_ctx_.get(), src_nodes, sorted_dst_comp_tasks, *src_parallel_desc,
*dst_parallel_desc, lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel));
CHECK(status.IsOk());
boxing_logger_->Log(*status);
}
}
......
......@@ -22,6 +22,7 @@ limitations under the License.
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/register/op_blob_arg_info.h"
#include "oneflow/core/graph/boxing/boxing_logger.h"
namespace oneflow {
......@@ -99,6 +100,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
std::vector<TaskNode*> ordered_task_nodes_;
std::shared_ptr<SubTskGphBuilder> sub_tsk_gph_builder_;
std::shared_ptr<SubTskGphBuilderCtx> sub_tsk_gph_builder_ctx_;
std::unique_ptr<BoxingLogger> boxing_logger_;
};
bool IsBackEdge(TaskNode* src, TaskNode* dst);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册