未验证 提交 3d33bde2 编写于 作者: L Li Xinqi 提交者: GitHub

Multi reentrant lock (#4225)

* source subset tick

* remove useless header files

* insert DstSubsetTickOp

* remove incorrect CHECK

* add tick op for each machine

* TryBindBnWithOneofRegst

* add sink tick op in main_job

* refactor LinkMainJob

* fix typo in task_graph

* refactor AddGlobalCriticalSection

* rename and refactor DstSubsetTick::InferBlobDescs and SrcSubsetTick::InferBlobDescs

* add src_subset_tick for input-output critical section

* refactor AutoSourceTick and AutoSinkTick

* vectorizedly link main job

* resize vectorh identity_tick_op_names then access elements

* SrcSubsetTickCompTaskNode: bind bns and in_regst if bns is valid in current device

* refactor optional input to repeated inputs for SrcSubsetTickOpConf

* fix a bug in CaseCompTaskNode; fix a bug when create identity tick in main_job

* 1) Insert tick between sourc tick and src_subset_tick; 2) Insert tick between dst_subset_tick and sink tick

* stash code

* refactor MakeMainJob by using Range::ForEachSubRange

* refactor MakeMainJob by using Range::ForEachSubRange

* rename ReentrantLockLinkPoint to ReentrantLockBackEdge

* set piece id for regst sent by wait_and_send_ids actor

* callback_notifier_sink_tick
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 b4fcfd50
......@@ -24,6 +24,7 @@ void WaitAndSendIdsCompActor::VirtualCompActorInit(const TaskProto& task_proto)
wait_and_send_ids_status_.in_id_ = 0;
wait_and_send_ids_status_.out_idx_ = 0;
wait_and_send_ids_status_.out_num_ = 0;
cur_piece_id_ = -1;
OF_SET_MSG_HANDLER(&WaitAndSendIdsCompActor::HandlerWaitToStart);
}
......@@ -36,7 +37,10 @@ void WaitAndSendIdsCompActor::Act() {
void WaitAndSendIdsCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
if (wait_and_send_ids_status_.buffer_status_ == kBufferStatusSuccess) {
HandleProducedNaiveDataRegstToConsumer();
HandleProducedNaiveDataRegstToConsumer([&](Regst* regst) {
regst->set_piece_id(++cur_piece_id_);
return true;
});
}
}
......
......@@ -41,6 +41,7 @@ class WaitAndSendIdsCompActor final : public CompActor {
int HandlerWaitToStart(const ActorMsg&);
WaitAndSendIdsStatus wait_and_send_ids_status_;
int64_t cur_piece_id_;
};
} // namespace oneflow
......
......@@ -27,6 +27,17 @@ void Range::ToProto(RangeProto* ret) const {
ret->set_end(end_);
}
Maybe<void> Range::ForEachSubRange(
int64_t sub_range_size, const std::function<Maybe<void>(const Range&)>& DoEachRange) const {
CHECK_EQ_OR_RETURN(size() % sub_range_size, 0);
int64_t start = begin();
for (; start < end(); start += sub_range_size) {
JUST(DoEachRange(Range(start, start + sub_range_size)));
}
CHECK_EQ_OR_RETURN(start, end());
return Maybe<void>::Ok();
}
Range FindIntersectant(const Range& lhs, const Range& rhs) {
if (lhs.end() > rhs.begin() && rhs.end() > lhs.begin()) {
int64_t left = lhs.begin() > rhs.begin() ? lhs.begin() : rhs.begin();
......@@ -37,4 +48,4 @@ Range FindIntersectant(const Range& lhs, const Range& rhs) {
}
}
} // namespace oneflow
\ No newline at end of file
} // namespace oneflow
......@@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_COMMON_RANGE_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/range.pb.h"
namespace oneflow {
......@@ -41,6 +42,9 @@ class Range final {
int64_t size() const { return end_ - begin_; }
Maybe<void> ForEachSubRange(int64_t sub_range_size,
const std::function<Maybe<void>(const Range&)>& DoEachRange) const;
void ToProto(RangeProto* ret) const;
private:
......
......@@ -189,6 +189,12 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic
std::shared_ptr<const ParallelDesc> src_pd = src_node->parallel_desc();
std::shared_ptr<const ParallelDesc> dst_pd = dst_node->parallel_desc();
if (src_node->op_vec().size() == 1 && dst_node->op_vec().size() == 1) {
if (src_node->SoleOp()->op_conf().has_wait_and_send_ids_conf()
&& dst_node->SoleOp()->op_conf().has_reentrant_lock_conf()) {
CHECK_EQ(src_pd->parallel_num(), 1);
CHECK_EQ(dst_pd->parallel_num(), 1);
return &TaskGraph::BldSubTskGphByBoxing;
}
if (src_node->SoleOp()->op_conf().has_record_load_conf()
&& dst_node->SoleOp()->op_conf().has_tick_conf()) {
CHECK(src_pd->parallel_num() == dst_pd->parallel_num());
......
此差异已折叠。
......@@ -94,9 +94,6 @@ Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section,
JUST(BuildDstSubsetTickOpAndParallelConf(tick_lbis, &dst_subset_tick, job_builder));
int64_t num_machines = Global<ResourceDesc, ForSession>::Get()->TotalMachineNum();
auto* map = critical_section->mutable_machine_id2sink_tick_op_name();
ParallelConf cpu0_parallel_conf;
cpu0_parallel_conf.set_device_tag("cpu");
cpu0_parallel_conf.add_device_name("0:0");
for (int64_t machine_id = 0; machine_id < num_machines; ++machine_id) {
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
......@@ -116,7 +113,7 @@ Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section,
auto* sink_tick_conf = sink_tick_op.mutable_sink_tick_conf();
sink_tick_conf->add_tick(tick_op.name() + "/out");
sink_tick_conf->set_out("out");
JUST(job_builder->AddOp(cpu0_parallel_conf, sink_tick_op));
JUST(job_builder->AddOp(parallel_conf, sink_tick_op));
}
(*map)[machine_id] = sink_tick_op.name();
}
......@@ -142,15 +139,15 @@ Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,
JobBuilder* job_builder) {
int64_t num_machines = Global<ResourceDesc, ForSession>::Get()->TotalMachineNum();
auto* map = critical_section->mutable_machine_id2source_tick_op_name();
ParallelConf cpu0_parallel_conf;
cpu0_parallel_conf.set_device_tag("cpu");
cpu0_parallel_conf.add_device_name("0:0");
for (int64_t machine_id = 0; machine_id < num_machines; ++machine_id) {
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name(std::to_string(machine_id) + ":0");
OperatorConf src_tick_op;
{
src_tick_op.set_name("System-AutoTick-SourceTick_" + NewUniqueId());
src_tick_op.mutable_source_tick_conf()->set_out("out");
JUST(job_builder->AddOp(cpu0_parallel_conf, src_tick_op));
JUST(job_builder->AddOp(parallel_conf, src_tick_op));
}
(*map)[machine_id] = src_tick_op.name();
OperatorConf tick_op;
......@@ -158,9 +155,6 @@ Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,
tick_op.set_name("System-AutoTick-Tick_" + NewUniqueId());
tick_op.mutable_tick_conf()->add_tick(src_tick_op.name() + "/out");
tick_op.mutable_tick_conf()->set_out("out");
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name(std::to_string(machine_id) + ":0");
JUST(job_builder->AddOp(parallel_conf, tick_op));
}
src_subset_tick_op->mutable_src_subset_tick_conf()->add_in(tick_op.name() + "/out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册