未验证 提交 9f504dfc 编写于 作者: J Juncheng 提交者: GitHub

Add NaiveB2PSubTskGphBuilder (#3942)

* Add NaiveB2PSubTskGphBuilder

* refine

* refine

* refine

* refine
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 c4bbf8c0
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/actor/naive_actor.h"
namespace oneflow {
// Actor registered for TaskType::kBoxingZeros tasks.
//
// NaiveActor::Act() is executed only on the first Act() call; every later call is a no-op
// (presumably because the zero-filled output never changes between pieces -- confirm with the
// kernel). Each batch of produced regsts handed to consumers is stamped with a piece id that
// starts at 0 and increases by one per send.
class BoxingZerosActor : public NaiveActor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BoxingZerosActor);
  BoxingZerosActor() = default;
  ~BoxingZerosActor() override = default;

  // Initializes the NaiveActor base from `task_proto` and resets this actor's own state.
  void VirtualActorInit(const TaskProto& task_proto) override {
    NaiveActor::VirtualActorInit(task_proto);
    piece_id_ = 0;
    out_inited_ = false;
  }

 private:
  // Runs the base-class act exactly once; subsequent calls do nothing.
  void Act() override {
    if (out_inited_) { return; }
    NaiveActor::Act();
    out_inited_ = true;
  }

  // Stamps every produced naive data regst with the current piece id, accepts all of them
  // (the handler always returns true), then advances the piece id.
  void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override {
    // Captured by value so the handler never refers to a local of this function.
    const int64_t current_piece_id = piece_id_;
    HandleProducedNaiveDataRegstToConsumer([current_piece_id](Regst* regst) {
      regst->set_piece_id(current_piece_id);
      return true;
    });
    piece_id_ += 1;
  }

  // In-class initializers keep the members well-defined even before VirtualActorInit runs.
  int64_t piece_id_ = 0;
  bool out_inited_ = false;
};
// Associates TaskType::kBoxingZeros with BoxingZerosActor through the REGISTER_ACTOR macro
// (the macro itself is defined outside this file).
REGISTER_ACTOR(TaskType::kBoxingZeros, BoxingZerosActor);
} // namespace oneflow
......@@ -30,8 +30,10 @@ class NaiveActor : public Actor {
OF_SET_MSG_HANDLER(&NaiveActor::HandlerNormal);
}
protected:
void Act() override;
private:
void Act() override final;
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;
};
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h"
#include "oneflow/core/graph/boxing_zeros_task_node.h"
namespace oneflow {
// Builds the boxing sub task graph for a source that is broadcast (or lives on a single
// device) feeding a destination that is partial-sum over more than one device.
//
// For every destination task the nearest source task is looked up. The destination task with
// the smallest distance to its nearest source (the first one in sorted order on ties) receives
// the source data through a proxy node. Every other destination task is fed by a freshly
// created BoxingZerosTaskNode placed on that destination's machine/thread/area.
//
// Returns:
//   - the status produced by BuildSubTskGphBuilderStatus on success;
//   - Error::BoxingNotSupportedError() when the parallel/sbp combination is not the one
//     described above;
//   - a check failure when no nearest source is found or a destination task is repeated.
Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
    SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
    const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc,
    const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi,
    const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel,
    const SbpParallel& dst_sbp_parallel) const {
  const bool is_supported =
      (src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel())
      && dst_parallel_desc.parallel_num() != 1 && dst_sbp_parallel.has_partial_sum_parallel();
  if (!is_supported) { return Error::BoxingNotSupportedError(); }

  // Signed copy of the size so the index loops below do not mix signed and unsigned operands.
  const int64_t dst_num = static_cast<int64_t>(sorted_dst_comp_tasks.size());

  // Pass 1: remember each destination's nearest source and find the destination that is
  // closest to its nearest source.
  HashMap<CompTaskNode*, CompTaskNode*> dst_node2nearest_src_node;
  int64_t nearest_dst_node_idx = -1;
  int64_t nearest_dst_node_distance = -1;
  for (int64_t dst_node_idx = 0; dst_node_idx < dst_num; ++dst_node_idx) {
    CompTaskNode* dst_node = sorted_dst_comp_tasks.at(dst_node_idx);
    const int64_t nearest_src_node_idx =
        SubTskGphBuilderUtil::FindNearestNodeIndex(sorted_src_comp_tasks, dst_node);
    CHECK_NE_OR_RETURN(nearest_src_node_idx, -1);
    CompTaskNode* nearest_src_node = sorted_src_comp_tasks.at(nearest_src_node_idx);
    // emplace().second is false when dst_node was already inserted, i.e. a duplicate task.
    CHECK_OR_RETURN(dst_node2nearest_src_node.emplace(dst_node, nearest_src_node).second);
    const int64_t distance = SubTskGphBuilderUtil::GetDistance(nearest_src_node, dst_node);
    // Strict '<' keeps the first destination among equally distant ones.
    if (nearest_dst_node_idx == -1 || distance < nearest_dst_node_distance) {
      nearest_dst_node_idx = dst_node_idx;
      nearest_dst_node_distance = distance;
    }
  }

  // Pass 2: wire the chosen destination to the real data and all others to zeros.
  for (int64_t dst_node_idx = 0; dst_node_idx < dst_num; ++dst_node_idx) {
    CompTaskNode* dst_node = sorted_dst_comp_tasks.at(dst_node_idx);
    CompTaskNode* nearest_src_node = dst_node2nearest_src_node.at(dst_node);
    if (dst_node_idx == nearest_dst_node_idx) {
      TaskNode* proxy = ctx->GetProxyNode(nearest_src_node, nearest_src_node->MemZoneId121(),
                                          dst_node->machine_id(), dst_node->MemZoneId121());
      Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node);
    } else {
      auto* zeros_node = ctx->task_graph()->NewNode<BoxingZerosTaskNode>();
      zeros_node->Init(dst_node->machine_id(), dst_node->thrd_id(), dst_node->area_id(), lbi,
                       logical_blob_desc.shape(), logical_blob_desc.data_type(),
                       *nearest_src_node->logical_node()->out_blob_time_shape());
      // The zeros node is additionally tied to the source by a ctrl regst and an edge
      // (presumably so that it is paced by the source -- confirm against BuildCtrlRegstDesc).
      nearest_src_node->BuildCtrlRegstDesc(zeros_node);
      Connect<TaskNode>(nearest_src_node, ctx->task_graph()->NewEdge(), zeros_node);
      Connect<TaskNode>(zeros_node, ctx->task_graph()->NewEdge(), dst_node);
    }
  }

  return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(),
                                         sorted_dst_comp_tasks.front(), src_parallel_desc,
                                         dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel,
                                         lbi, logical_blob_desc, "NaiveB2PSubTskGphBuilder", ""));
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_
#define ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_
#include "oneflow/core/graph/boxing/sub_task_graph_builder.h"
namespace oneflow {
// Sub task graph builder for the "broadcast (or single device) source -> partial-sum
// destination on more than one device" boxing case. Not copyable or movable.
class NaiveB2PSubTskGphBuilder final : public SubTskGphBuilder {
public:
OF_DISALLOW_COPY_AND_MOVE(NaiveB2PSubTskGphBuilder);
NaiveB2PSubTskGphBuilder() = default;
~NaiveB2PSubTskGphBuilder() override = default;
// Connects the source tasks to the destination tasks for blob `lbi`.
// Returns Error::BoxingNotSupportedError() when the source/destination parallel
// descriptions and sbp signatures are not the combination this builder handles.
Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
const ParallelDesc& src_parallel_desc,
const ParallelDesc& dst_parallel_desc,
const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
const SbpParallel& src_sbp_parallel,
const SbpParallel& dst_sbp_parallel) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_NAIVE_B2P_SUB_TASK_GRAPH_BUILDER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/boxing_zeros_task_node.h"
namespace oneflow {
// Records where this task node runs (machine, thread, area) and the description of the blob
// it produces (logical blob id, shape, data type, time shape).
void BoxingZerosTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id,
                               const LogicalBlobId& lbi, const Shape& shape, DataType data_type,
                               const Shape& time_shape) {
  // Placement of the task.
  set_machine_id(machine_id);
  set_thrd_id(thrd_id);
  set_area_id(area_id);

  // Description of the produced blob, used later when the op and regst are built.
  lbi_ = lbi;
  shape_ = shape;
  data_type_ = data_type;
  time_shape_ = time_shape;
}
void BoxingZerosTaskNode::ProduceAllRegstsAndBindEdges() {
std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", false, 1, 1);
this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); });
}
// Intentionally empty: this node does not consume any regst here.
void BoxingZerosTaskNode::ConsumeAllRegsts() {}
void BoxingZerosTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
OperatorConf op_conf;
op_conf.set_name("System-Boxing-Zeros-" + NewUniqueId());
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type())));
*op_conf.mutable_boxing_zeros_conf()->mutable_lbi() = lbi_;
shape_.ToProto(op_conf.mutable_boxing_zeros_conf()->mutable_shape());
op_conf.mutable_boxing_zeros_conf()->set_data_type(data_type_);
std::shared_ptr<Operator> sole_op = ConstructOp(op_conf, &GlobalJobDesc());
node->mut_op() = sole_op;
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->InferBlobDescs(nullptr);
}
void BoxingZerosTaskNode::InferProducedDataRegstTimeShape() {
GetProducedRegst("out")->mut_data_regst_time_shape()->reset(new Shape(time_shape_));
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
// Task node of type TaskType::kBoxingZeros. It produces a single regst named "out" whose
// blob is described by the lbi/shape/data type passed to Init. Not copyable or movable.
class BoxingZerosTaskNode : public TaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(BoxingZerosTaskNode);
BoxingZerosTaskNode() = default;
~BoxingZerosTaskNode() override = default;
// Sets the placement (machine, thread, area) and stores the description of the produced
// blob. NOTE(review): presumably must be called before the graph-building overrides below.
void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi,
const Shape& shape, DataType data_type, const Shape& time_shape);
TaskType GetTaskType() const override { return TaskType::kBoxingZeros; }
private:
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() final;
void InferProducedDataRegstTimeShape() final;
// Logical blob id of the produced blob.
LogicalBlobId lbi_;
// Shape of the produced blob.
Shape shape_;
// Element type of the produced blob; not initialized until Init is called.
DataType data_type_;
// Time shape assigned to the produced "out" regst.
Shape time_shape_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_ZEROS_TASK_NODE_H_
......@@ -30,6 +30,7 @@ limitations under the License.
#include "oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h"
#include "oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.h"
......@@ -224,6 +225,7 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) {
builders.emplace_back(new CollectiveBoxingSubTskGphBuilder());
builders.emplace_back(new SliceBoxingSubTskGphBuilder());
builders.emplace_back(new NaiveB2BSubTskGphBuilder());
builders.emplace_back(new NaiveB2PSubTskGphBuilder());
sub_tsk_gph_builder_.reset(new ChainSubTskGphBuilder(builders));
HashMap<const LogicalNode*, std::vector<CompTaskNode*>> logical2sorted_comp_tasks;
HashMap<const LogicalNode*, std::vector<TaskNode*>> logical2sorted_in_box;
......
......@@ -38,6 +38,7 @@ enum TaskType {
kBoxingS2SAll2AllPack = 61;
kBoxingS2SAll2AllUnpack = 62;
kSspVariableProxy = 63;
kBoxingZeros = 64;
};
enum AreaType {
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/kernel_context.h"
namespace oneflow {
// Kernel, parameterized by device type, whose forward pass fills the "out" blob with zero
// bytes. Not copyable or movable.
template<DeviceType device_type>
class BoxingZerosKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(BoxingZerosKernel);
BoxingZerosKernel() = default;
~BoxingZerosKernel() override = default;
private:
// Writes the data content of the output; `BnInOp2Blob` maps a blob name to its Blob.
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
};
// Sets every byte of the body of the "out" blob to 0 on the kernel's device context.
template<DeviceType device_type>
void BoxingZerosKernel<device_type>::ForwardDataContent(
    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  Blob* const out_blob = BnInOp2Blob("out");
  const auto body_byte_size = out_blob->ByteSizeOfBlobBody();
  Memset<device_type>(ctx.device_ctx, out_blob->mut_dptr(), 0, body_byte_size);
}
// Registers BoxingZerosKernel as the kernel for operators configured with
// OperatorConf::kBoxingZerosConf (macro defined outside this file; presumably instantiates
// the template per device type -- confirm against the macro).
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kBoxingZerosConf, BoxingZerosKernel);
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
// Operator configured by boxing_zeros_conf. It enrolls a single output blob name "out" whose
// shape and data type come straight from the configuration. Not copyable or movable.
class BoxingZerosOp : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(BoxingZerosOp);
BoxingZerosOp() = default;
~BoxingZerosOp() override = default;
// Enrolls the blob names of this operator.
void InitFromOpConf() override;
// Fills the blob description of "out" from the configuration.
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
// Both lookups return the lbi stored in boxing_zeros_conf, regardless of the blob name given.
LogicalBlobId lbi4ibn(const std::string& input_bn) const override;
LogicalBlobId lbi4obn(const std::string& output_bn) const override;
};
// Enrolls the only blob name of this operator: the output "out".
// NOTE(review): the `false` argument's meaning is defined by EnrollOutputBn -- confirm there.
void BoxingZerosOp::InitFromOpConf() {
  EnrollOutputBn("out", false);
}
// Returns the lbi stored in boxing_zeros_conf; `input_bn` is not consulted.
LogicalBlobId BoxingZerosOp::lbi4ibn(const std::string& input_bn) const {
  const auto& zeros_conf = op_conf().boxing_zeros_conf();
  return zeros_conf.lbi();
}
// Returns the lbi stored in boxing_zeros_conf; `output_bn` is not consulted.
LogicalBlobId BoxingZerosOp::lbi4obn(const std::string& output_bn) const {
  const auto& zeros_conf = op_conf().boxing_zeros_conf();
  return zeros_conf.lbi();
}
// Copies the shape and data type from boxing_zeros_conf into the blob description of "out".
// `parallel_ctx` is not used. Always returns Maybe<void>::Ok().
Maybe<void> BoxingZerosOp::InferBlobDescs(
    std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
    const ParallelContext* parallel_ctx) const {
  const BoxingZerosOpConf& zeros_conf = op_conf().boxing_zeros_conf();
  BlobDesc* const out_desc = GetBlobDesc4BnInOp("out");
  out_desc->mut_shape() = Shape(zeros_conf.shape());
  out_desc->set_data_type(zeros_conf.data_type());
  return Maybe<void>::Ok();
}
// Associates OperatorConf::kBoxingZerosConf with BoxingZerosOp through the REGISTER_OP macro
// (the macro itself is defined outside this file).
REGISTER_OP(OperatorConf::kBoxingZerosConf, BoxingZerosOp);
} // namespace oneflow
......@@ -864,6 +864,12 @@ message ImageDecoderRandomCropResizeOpConf {
optional float random_aspect_ratio_max = 13 [default = 1.333333];
}
// Configuration of the boxing-zeros operator: identifies the blob it produces and gives the
// blob's shape and element type. All three fields are required.
message BoxingZerosOpConf {
// Logical blob id of the produced blob.
required LogicalBlobId lbi = 1;
// Shape of the produced blob.
required ShapeProto shape = 2;
// Element type of the produced blob.
required DataType data_type = 3;
}
message OperatorConf {
required string name = 1;
optional bool trainable = 3 [default = true];
......@@ -914,6 +920,7 @@ message OperatorConf {
CastToStaticShapeOpConf cast_to_static_shape_conf = 173;
BoxingS2SAll2AllPackOpConf boxing_s2s_all2all_pack_conf = 174;
BoxingS2SAll2AllUnpackOpConf boxing_s2s_all2all_unpack_conf = 175;
BoxingZerosOpConf boxing_zeros_conf = 176;
UserOpConf user_conf = 199;
// domain op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册