提交 5d0009ae 编写于 作者: L lixinqi

oneflow.experimental.delay_tick


Former-commit-id: bbd11718d498eac4639382dd6c0375f4670a0611
上级 c6705ca9
......@@ -18,6 +18,10 @@ limitations under the License.
namespace oneflow {
// Deliberately a no-op: this actor does nothing when it acts.
void DelayTickCompActor::Act() {}
void DelayTickCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
eord_received_ = false;
{
......@@ -57,13 +61,6 @@ void DelayTickCompActor::ForEachCurCustomizedReadableRegst(
Handler(consumed_rs_.Front(consumed_regst_desc_id_));
}
void DelayTickCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
HandleProducedNaiveDataRegstToConsumer([this](Regst* regst) {
regst->set_piece_id(consumed_rs_.Front(consumed_regst_desc_id_)->piece_id());
return true;
});
}
void DelayTickCompActor::AsyncReturnCurCustomizedReadableRegst() {
Regst* regst = consumed_rs_.Front(consumed_regst_desc_id_);
CHECK(regst);
......@@ -112,6 +109,7 @@ void DelayTickCompActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) {
void DelayTickCompActor::AsyncSendCustomizedProducedRegstMsgToConsumer() {
Regst* const regst = produced_rs_.Front(produced_regst_desc_id_);
regst->set_piece_id(consumed_rs_.Front(consumed_regst_desc_id_)->piece_id());
CHECK_GT(HandleRegstToConsumer(regst, [](int64_t) { return true; }), 0);
produced_rs_.PopFrontRegsts({produced_regst_desc_id_});
}
......
......@@ -34,7 +34,7 @@ class DelayTickCompActor final : public CompActor {
bool ProducedCtrlRegstValid(int64_t regst_desc_id) const override { return true; }
private:
void Act() override {}
void Act() override;
// consumed regst slot
std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()
override {
......@@ -48,7 +48,6 @@ class DelayTickCompActor final : public CompActor {
void AsyncReturnCurCustomizedReadableRegst();
void AsyncReturnAllCustomizedReadableRegst() override;
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;
void AsyncSendCustomizedConsumedRegstMsgToProducer() override;
void TakeOverConsumedRegst(const PbMap<std::string, RegstDescIdSet>& consumed_ids);
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/delay_tick_compute_task_node.h"
#include "oneflow/core/graph/logical_node.h"
namespace oneflow {
void DelayTickCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("out", false, 1, 1);
ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); });
}
void DelayTickCompTaskNode::ConsumeAllRegsts() {
ConsumeRegst("in");
ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); });
}
void DelayTickCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = logical_node()->SoleOp();
const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst("in");
for (const std::string& ibn : node->op()->input_bns()) {
node->BindBnWithOneOfTheRegsts(ibn, in_regsts);
}
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
for (const std::string& obn : node->op()->output_bns()) {
const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn);
out_regst->AddLbi(lbi);
node->BindBnWithRegst(obn, out_regst);
}
node->InferBlobDescs(parallel_ctx());
}
void DelayTickCompTaskNode::InferProducedDataRegstTimeShape() {
auto time_shape = (*in_edges().begin())->src_node()->GetFastestInputOutputTimeShape();
for (TaskEdge* edge : in_edges()) {
CHECK(time_shape->elem_cnt() == edge->src_node()->GetFastestInputOutputTimeShape()->elem_cnt());
}
ForEachProducedDataRegst([time_shape](const std::string& name, RegstDesc* regst) {
*regst->mut_data_regst_time_shape() = time_shape;
});
}
REGISTER_TICK_TOCK_TASK_TYPE(TaskType::kDelayTick);
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_DELAY_TICK_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_DELAY_TICK_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
// Compute task node for the kDelayTick task type.
//
// Non-copyable and non-movable. It reports itself as meaningful
// (IsMeaningLess() is false) and as independent (IsIndependent() is true).
class DelayTickCompTaskNode : public CompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DelayTickCompTaskNode);
  DelayTickCompTaskNode() = default;
  virtual ~DelayTickCompTaskNode() = default;

  TaskType GetTaskType() const override { return TaskType::kDelayTick; }
  bool IsMeaningLess() override { return false; }

 private:
  bool IsIndependent() const override { return true; }

  // Regst wiring and inference hooks; defined in the matching .cpp file.
  void ConsumeAllRegsts() override;
  void ProduceAllRegstsAndBindEdges() override;
  void BuildExecGphAndRegst() override;
  void InferProducedDataRegstTimeShape() override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DELAY_TICK_COMPUTE_TASK_NODE_H_
......@@ -231,6 +231,10 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic
return &TaskGraph::BldSubTskGphByBoxing;
}
REGISTER_BLD_SUB_TSK_GPH_MTHD("*"
"DelayTick",
&TaskGraph::BldSubTskGphByStrictOneToOne);
REGISTER_BLD_SUB_TSK_GPH_MTHD("RecordLoad"
"Decode",
&TaskGraph::BldSubTskGphByOneToOne);
......
......@@ -27,6 +27,7 @@ limitations under the License.
#include "oneflow/core/graph/reentrant_lock_compute_task_node.h"
#include "oneflow/core/graph/source_tick_compute_task_node.h"
#include "oneflow/core/graph/tick_compute_task_node.h"
#include "oneflow/core/graph/delay_tick_compute_task_node.h"
#include "oneflow/core/graph/device_tick_compute_task_node.h"
#include "oneflow/core/graph/acc_tick_compute_task_node.h"
#include "oneflow/core/graph/repeat_forward_compute_task_node.h"
......
......@@ -476,6 +476,15 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
}
}
// Connects the i-th sorted source task to the i-th sorted destination task.
//
// "Strict": the source and destination logical nodes must have equal parallel
// descriptions, and the two task lists must have the same length.
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByStrictOneToOne) {
  CHECK(*src_logical->parallel_desc() == *dst_logical->parallel_desc());
  // Same precondition BldSubTskGphByOneToOne enforces. Without it, surplus
  // destination tasks would be silently left unconnected.
  CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
  FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
    CompTaskNode* src = sorted_src_comp_tasks.at(i);
    CompTaskNode* dst = sorted_dst_comp_tasks.at(i);
    BuildTaskPath(src, dst, MutBufTask, true);
  }
}
DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne) {
CHECK_EQ(sorted_src_comp_tasks.size(), sorted_dst_comp_tasks.size());
FOR_RANGE(size_t, i, 0, sorted_src_comp_tasks.size()) {
......
......@@ -49,6 +49,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
#define DECLARE_BLD_SUB_TASK_GRAPH_METHOD(method_name) void method_name BLD_SUB_TSK_GPH_MTHD_ARGS();
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByStrictOneToOne);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByOneToOne);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect);
......
......@@ -55,6 +55,5 @@ void TickCompTaskNode::InferProducedDataRegstTimeShape() {
}
REGISTER_TICK_TOCK_TASK_TYPE(TaskType::kTick);
REGISTER_TICK_TOCK_TASK_TYPE(TaskType::kDelayTick);
} // namespace oneflow
......@@ -37,13 +37,6 @@ class TickCompTaskNode : public CompTaskNode {
bool IsIndependent() const override { return true; }
};
// Delay-tick task node declared as a TickCompTaskNode subclass. It declares
// only a defaulted constructor and destructor and disallows copy and move.
class DelayTickCompTaskNode : public TickCompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(DelayTickCompTaskNode);
DelayTickCompTaskNode() = default;
~DelayTickCompTaskNode() override = default;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NORMAL_FORWARD_COMPUTE_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_TICK_COMPUTE_TASK_NODE_H_
......@@ -28,7 +28,9 @@ class DelayTickKernel final : public KernelIf<device_type> {
private:
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
LOG(ERROR) << "\n" << this->op_conf().DebugString();
}
// Returns the delay_tick_conf section of this kernel's operator conf.
const PbMessage& GetCustomizedOpConf() const override {
  const auto& conf = this->op_conf();
  return conf.delay_tick_conf();
}
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from typing import Optional, Tuple
import oneflow as flow
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.python.framework.interpret_util as interpret_util
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.input_blob_def as input_blob_util
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.dtype as dtype_util
from oneflow.python.oneflow_export import oneflow_export
@oneflow_export("experimental.delay_tick")
def delay_tick(
    x: input_blob_util.ArgBlobDef, delay_num: int = 0, name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
    """Add a DelayTick operator that consumes `x` and return its output blob.

    Args:
        x: Blob whose `unique_name` is wired to the operator's `tick` input.
        delay_num: Value stored in `delay_tick_conf.delay_num`. Defaults to 0.
        name: Operator name. When None, a unique name with the "DelayTick_"
            prefix is generated.

    Returns:
        A remote blob referring to blob "out" of the new operator.
    """
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("DelayTick_") if name is None else name

    tick_conf = op_conf.delay_tick_conf
    tick_conf.tick = x.unique_name
    tick_conf.out = "out"
    tick_conf.delay_num = delay_num
    interpret_util.Forward(op_conf)

    # Address the operator's output by (op name, blob name).
    out_lbi = logical_blob_id_util.LogicalBlobId()
    out_lbi.op_name, out_lbi.blob_name = op_conf.name, "out"
    return remote_blob_util.RemoteBlob(out_lbi)
......@@ -21,7 +21,6 @@ import oneflow as flow
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.python.framework.interpret_util as interpret_util
import oneflow.python.framework.distribute as distribute_util
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.input_blob_def as input_blob_util
import oneflow.python.framework.remote_blob as remote_blob_util
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow.typing as tp
import os
import unittest
@flow.unittest.skip_unless_1n1d()
class Test1dDelayTick(flow.unittest.TestCase):
    """Checks `flow.experimental.delay_tick` with `delay_num=0` on GPU "0:0".

    Both tests return early when eager execution is enabled, and both are
    skipped when ONEFLOW_TEST_CPU_ONLY is set because they place ops on "gpu".
    """

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_1d_no_delay(test_case):
        """The job's Numpy result has shape (1,)."""
        if flow.eager_execution_enabled():
            return
        device_name = "0:0"
        flow.config.gpu_device_num(2)

        @flow.global_function()
        def Foo() -> tp.Numpy:
            with flow.scope.placement("gpu", device_name):
                w = flow.get_variable(
                    "w",
                    shape=(10,),
                    dtype=flow.float,
                    initializer=flow.constant_initializer(0),
                )
                return flow.experimental.delay_tick(w, delay_num=0)

        x = Foo()
        test_case.assertTrue(x.shape == (1,))

    # Same guard as test_1d_no_delay: this test also places its ops on "gpu".
    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_1d_no_delay_with_callback(test_case):
        """Same job as above, with the result delivered through a callback."""
        if flow.eager_execution_enabled():
            return
        device_name = "0:0"
        flow.config.gpu_device_num(2)

        @flow.global_function()
        def Foo() -> tp.Callback[tp.Numpy]:
            with flow.scope.placement("gpu", device_name):
                w = flow.get_variable(
                    "w",
                    shape=(10,),
                    dtype=flow.float,
                    initializer=flow.constant_initializer(0),
                )
                return flow.experimental.delay_tick(w, delay_num=0)

        future = Foo()
        future(lambda x: test_case.assertTrue(x.shape == (1,)))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册