diff --git a/oneflow/core/actor/actor_message.cpp b/oneflow/core/actor/actor_message.cpp
index 7c8da88278fff2ae5ee9ad58490d95fdfac13696..ef7e207ba1a26c8e87abda516beafe4c3ee3a220 100644
--- a/oneflow/core/actor/actor_message.cpp
+++ b/oneflow/core/actor/actor_message.cpp
@@ -4,6 +4,24 @@
 
 namespace oneflow {
 
+namespace {
+
+bool IsSoleBlobAndDynamicEmpty(Regst* regst) {
+  if (regst == nullptr) { return false; }
+  if (regst->GetBlobSize() != 1) { return false; }
+  Blob* sole_blob = regst->GetMutSoleBlob();
+  if (sole_blob->num_of_tensor_list_slices() != 1) { return false; }
+  if (sole_blob->total_num_of_tensors() != 1) { return false; }
+  if (!regst->GetSoleBlob()->IsBodyEmpty()) { return false; }
+  const auto& shape = sole_blob->shape();
+  for (int i = 0; i < shape.NumAxes(); ++i) {
+    if (shape.At(i) != 0) { return false; }
+  }
+  return true;
+}
+
+}  // namespace
+
 ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
                                            Regst* regst_raw_ptr) {
   ActorMsg msg;
@@ -18,6 +36,8 @@ ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
     msg.regst_wrapper_.comm_net_token = regst_raw_ptr->comm_net_token();
   }
   msg.regst_wrapper_.regst_status = regst_raw_ptr->status();
+  msg.regst_wrapper_.has_sole_empty_tensor_in_sole_tensor_list =
+      IsSoleBlobAndDynamicEmpty(regst_raw_ptr);
   return msg;
 }
 
@@ -29,6 +49,8 @@ ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer,
   msg.msg_type_ = ActorMsgType::kRegstMsg;
   msg.regst_wrapper_.regst = regst_raw_ptr;
   msg.regst_wrapper_.comm_net_token = nullptr;
+  // the regst ptr can NOT be accessed across machines, because its address belongs to another machine
+  msg.regst_wrapper_.has_sole_empty_tensor_in_sole_tensor_list = false;
   return msg;
 }
 
@@ -89,6 +111,11 @@ void* ActorMsg::comm_net_token() const {
   return regst_wrapper_.comm_net_token;
 }
 
+bool ActorMsg::has_sole_empty_tensor_in_sole_tensor_list() const {
+  CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
+  return regst_wrapper_.has_sole_empty_tensor_in_sole_tensor_list;
+}
+
 int64_t ActorMsg::eord_regst_desc_id() const {
   CHECK_EQ(msg_type_, ActorMsgType::kEordMsg);
   return eord_regst_desc_id_;
diff --git a/oneflow/core/actor/actor_message.h b/oneflow/core/actor/actor_message.h
index 8eb4967d23f29df51b9aef69f6c9152232ffa6ac..32eaf8d55d0ffb24cfd890d1f575f5e6f38b4888 100644
--- a/oneflow/core/actor/actor_message.h
+++ b/oneflow/core/actor/actor_message.h
@@ -39,6 +39,7 @@ class ActorMsg final {
   int64_t piece_id() const;
   int64_t act_id() const;
   void* comm_net_token() const;
+  bool has_sole_empty_tensor_in_sole_tensor_list() const;
   int64_t eord_regst_desc_id() const;
 
   // Serialize
@@ -56,6 +57,7 @@ class ActorMsg final {
     Regst* regst;
     void* comm_net_token;
     RegstStatus regst_status;
+    bool has_sole_empty_tensor_in_sole_tensor_list;
   };
 
   int64_t src_actor_id_;
diff --git a/oneflow/core/actor/copy_comm_net_actor.cpp b/oneflow/core/actor/copy_comm_net_actor.cpp
index 4c09210cccdeb15b933accbe650b5d00303d9e1a..649d2f892dbe99d829f60affb11de07c1156a552 100644
--- a/oneflow/core/actor/copy_comm_net_actor.cpp
+++ b/oneflow/core/actor/copy_comm_net_actor.cpp
@@ -55,6 +55,8 @@ bool CopyCommNetActor::NormalTryProcessReadableMsgFromOtherMachine(const ActorMs
   regst_ctx.regst_raw_ptr = msg.regst();
   regst_ctx.producer = msg.src_actor_id();
   regst_ctx.act_id = msg.act_id();
+  regst_ctx.has_sole_empty_tensor_in_sole_tensor_list =
+      msg.has_sole_empty_tensor_in_sole_tensor_list();
   CHECK(piece_id2regst_ctx_.emplace(msg.piece_id(), regst_ctx).second);
   return true;
 }
@@ -66,9 +68,23 @@ void CopyCommNetActor::Act() {
   int64_t src_actor_id = readable_it->second.producer;
   int64_t src_machine_id = Global<IDMgr>::Get()->MachineId4ActorId(src_actor_id);
   // writeable
-  void* writeable_token = GetNaiveCurWriteable("copy_out")->comm_net_token();
-  // Async
-  Global<CommNet>::Get()->Read(actor_read_id_, src_machine_id, readable_token, writeable_token);
+  Regst* writeable_regst = GetNaiveCurWriteable("copy_out");
+  if (readable_it->second.has_sole_empty_tensor_in_sole_tensor_list) {
+    // pass if the regst's dynamic body is empty
+    Blob* data_blob = writeable_regst->GetMutSoleBlob();
+    TensorBackInserter back_inserter(data_blob);
+    back_inserter.ReserveOneEmptyTensorList();
+    FullyMutTensorView* tensor_view = back_inserter.add_tensor();
+    Shape empty_shape = data_blob->static_shape();
+    for (int i = 0; i < empty_shape.NumAxes(); ++i) { empty_shape.Set(i, 0); }
+    tensor_view->set_shape(empty_shape);
+    LOG(INFO) << "cclog: PASS";
+  } else {
+    void* writeable_token = writeable_regst->comm_net_token();
+    // Async
+    Global<CommNet>::Get()->Read(actor_read_id_, src_machine_id, readable_token, writeable_token);
+    LOG(INFO) << "cclog: READ";
+  }
 }
 
 void CopyCommNetActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
diff --git a/oneflow/core/actor/copy_comm_net_actor.h b/oneflow/core/actor/copy_comm_net_actor.h
index 58cf354a73b71b9543578839d1566420283d7d8b..8fb7bb17c5a06d70810bb0d2647d36aa7656a652 100644
--- a/oneflow/core/actor/copy_comm_net_actor.h
+++ b/oneflow/core/actor/copy_comm_net_actor.h
@@ -18,6 +18,7 @@ class CopyCommNetActor final : public Actor {
     Regst* regst_raw_ptr;
     int64_t producer;
     int64_t act_id;
+    bool has_sole_empty_tensor_in_sole_tensor_list;
   };
 
   void VirtualActorInit(const TaskProto&) override;
diff --git a/oneflow/core/kernel/relu_kernel.h b/oneflow/core/kernel/relu_kernel.h
index 5e4be8c2319f3d46cb23830e23b71ee342dc6aaa..dcd09a049e9bfc263497e9961e5f30d431aa0bbf 100644
--- a/oneflow/core/kernel/relu_kernel.h
+++ b/oneflow/core/kernel/relu_kernel.h
@@ -14,6 +14,7 @@ class ReluKernel final : public KernelIf<device_type> {
   ~ReluKernel() = default;
 
  private:
+  bool IsStateless() const override { return true; }
   void ForwardDataContent(const KernelCtx&,
                           std::function<Blob*(const std::string&)>) const override;
 };
diff --git a/oneflow/core/register/blob.cpp b/oneflow/core/register/blob.cpp
index 2849e985e27b9928ba969d3b9f0419a7e06f81b5..23a4fd47e0e6ea15d7c5e702db17b4ddfc3e6973 100644
--- a/oneflow/core/register/blob.cpp
+++ b/oneflow/core/register/blob.cpp
@@ -67,7 +67,6 @@ void Blob::Init(const MemoryCase& mem_case, const RtBlobDesc* blob_desc, char* h
       new TensorView(this, header_field(), dptr()));
   begin_mut_tensor_.reset(new DataOnlyMutTensorView(
       this, mut_header_field(), mut_dptr()));
-  tensor_back_inserter_.reset(new TensorBackInserter(this));
   int64_t* shape_ptr = mut_header_field();
   shape_view_.reset(new ShapeView(shape_ptr, static_shape().NumAxes()));
   if (blob_desc->is_dynamic()) {
diff --git a/oneflow/core/register/blob.h b/oneflow/core/register/blob.h
index 1f52fc2d09c1142f54250f96db70039cd74e03e2..789e7468890d80ed687bd9c6c39b6e71266deb49 100644
--- a/oneflow/core/register/blob.h
+++ b/oneflow/core/register/blob.h
@@ -54,7 +54,6 @@ class Blob final {
 
   bool IsEndTensor(const DataOnlyMutTensorView& tensor) const;
   friend class TensorBackInserter;
-  const TensorBackInserter& tensor_back_inserter() { return *tensor_back_inserter_; }
 
   // tensor list slice
   size_t num_of_tensor_list_slices() const;
@@ -128,7 +127,6 @@ class Blob final {
   std::unique_ptr header_ptr_;
   std::unique_ptr<TensorView> begin_tensor_;
   std::unique_ptr<DataOnlyMutTensorView> begin_mut_tensor_;
-  std::unique_ptr<TensorBackInserter> tensor_back_inserter_;
   // TODO(); remove this ugly code
   int32_t record_num_;
 };
diff --git a/oneflow/core/register/register.cpp b/oneflow/core/register/register.cpp
index b0f6edd924bccd19ae13e35efb2aba7e0d88c6b6..e07da743482b6f2a64bded45ca213fb1b9ee509b 100644
--- a/oneflow/core/register/register.cpp
+++ b/oneflow/core/register/register.cpp
@@ -38,4 +38,14 @@ void Regst::set_regst_desc(const RtRegstDesc* regst_desc) {
   status_.regst_desc_id = regst_desc_->regst_desc_id();
 }
 
+Blob* Regst::GetMutSoleBlob() {
+  CHECK_EQ(GetBlobSize(), 1);
+  return lbi2blob_.begin()->second.get();
+}
+
+const Blob* Regst::GetSoleBlob() const {
+  CHECK_EQ(GetBlobSize(), 1);
+  return lbi2blob_.begin()->second.get();
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/register/register.h b/oneflow/core/register/register.h
index 3ef2ba0c7a2c4de8d8bc8163b0fb3c60046e8beb..96b1eba1a3fa6f7ea7c6a9c3c411eb0017e91c01 100644
--- a/oneflow/core/register/register.h
+++ b/oneflow/core/register/register.h
@@ -34,6 +34,9 @@ class Regst final {
   const std::vector<int64_t>& consumers_actor_id() const;
   const RtRegstDesc* regst_desc() const { return regst_desc_; }
   Blob* GetBlobByLbi(const LogicalBlobId& lbi);
+  const Blob* GetSoleBlob() const;
+  Blob* GetMutSoleBlob();
+  int64_t GetBlobSize() const { return lbi2blob_.size(); }
   const HashMap<LogicalBlobId, std::unique_ptr<Blob>>& lbi2blob() const { return lbi2blob_; }
   Blob* packed_blob() { return packed_blob_.get(); }
   bool IsMaxCol() const { return col_id() == max_col_id(); }
diff --git a/oneflow/python/test/test_copy_comm_net_pass_empty.py b/oneflow/python/test/test_copy_comm_net_pass_empty.py
new file mode 100644
index 0000000000000000000000000000000000000000..74411e4800dd0500321edc8ab38ef5da1e9af5ec
--- /dev/null
+++ b/oneflow/python/test/test_copy_comm_net_pass_empty.py
@@ -0,0 +1,96 @@
+import oneflow as flow
+import numpy as np
+
+
+def ccrelu(x, name):
+    return flow.user_op_builder(name)\
+        .Op("ccrelu")\
+        .Input("in", [x])\
+        .Output("out")\
+        .Build().RemoteBlobList()[0]
+
+@flow.unittest.num_nodes_required(2)
+def test_multi_node_comm_net(test_case):
+    func_config = flow.FunctionConfig()
+    func_config.default_distribute_strategy(flow.distribute.consistent_strategy())
+    func_config.default_data_type(flow.float)
+    flow.config.gpu_device_num(1)
+
+    @flow.function(func_config)
+    def ReluJob(x = flow.FixedTensorDef((10, 2))):
+        with flow.fixed_placement("gpu", "0:0"):
+            out0 = ccrelu(x, "my_op_0_0")
+        with flow.fixed_placement("gpu", "1:0"):
+            out1 = ccrelu(out0, "my_op_1_0")
+        with flow.fixed_placement("gpu", "0:0"):
+            out2 = ccrelu(out1, "my_op_print")
+        return out2
+    index = [-2, -1, 0, 1, 2]
+    data = []
+    for i in index: data.append(np.ones((10, 2,), dtype=np.float32) * i)
+    for i in range(5):
+        ret = ReluJob(data[i]).get().ndarray()
+        print(ret)
+        if index[i] > 0:
+            test_case.assertTrue(np.array_equal(ret, np.ones((10, 2,), dtype=np.float32) * index[i]))
+        else:
+            test_case.assertTrue(np.array_equal(ret, np.zeros((10, 2,), dtype=np.float32)))
+
+@flow.unittest.num_nodes_required(2)
+def test_multi_node_comm_net_dynamic(test_case):
+    func_config = flow.FunctionConfig()
+    func_config.default_distribute_strategy(flow.distribute.mirrored_strategy())
+    func_config.default_placement_scope(flow.fixed_placement("gpu", "0:0"))
+    func_config.default_data_type(flow.float)
+    flow.config.machine_num(2)
+    flow.config.gpu_device_num(1)
+
+    @flow.function(func_config)
+    def ReluJob(x = flow.MirroredTensorDef((10, 2))):
+        with flow.fixed_placement("gpu", "0:0"):
+            out0 = flow.keras.activations.relu(x)
+        with flow.fixed_placement("gpu", "1:0"):
+            out1 = flow.keras.activations.relu(out0)
+        with flow.fixed_placement("gpu", "0:0"):
+            out2 = flow.keras.activations.relu(out1)
+        return out2
+    index = [-2, -1, 0, 1, 2]
+    data = []
+    for i in index: data.append(np.ones((5, 2,), dtype=np.float32) * i)
+    for i in range(5):
+        ret = ReluJob([data[i]]).get().ndarray_list()[0]
+        print(ret)
+        if index[i] > 0:
+            test_case.assertTrue(np.array_equal(ret, np.ones((5, 2,), dtype=np.float32) * index[i]))
+        else:
+            test_case.assertTrue(np.array_equal(ret, np.zeros((5, 2,), dtype=np.float32)))
+
+@flow.unittest.num_nodes_required(2)
+def test_multi_node_comm_net_dynamic_empty(test_case):
+    func_config = flow.FunctionConfig()
+    func_config.default_distribute_strategy(flow.distribute.mirrored_strategy())
+    func_config.default_placement_scope(flow.fixed_placement("cpu", "0:0"))
+    func_config.default_data_type(flow.float)
+    flow.config.machine_num(2)
+    flow.config.gpu_device_num(1)
+
+    @flow.function(func_config)
+    def ReluJob(x = flow.MirroredTensorDef((10, 2))):
+        with flow.fixed_placement("cpu", "0:0"):
+            out0 = flow.keras.activations.relu(x)
+        with flow.fixed_placement("cpu", "1:0"):
+            out1 = flow.keras.activations.relu(out0)
+        with flow.fixed_placement("cpu", "0:0"):
+            out2 = flow.keras.activations.relu(out1)
+        return out2
+    index = [-2, -1, 0, 1, 2]
+    data = []
+    for i in index: data.append(np.ones((0, 0,), dtype=np.float32) * i)
+    for i in range(5):
+        ret = ReluJob([data[i]]).get().ndarray_list()[0]
+        print(ret)
+        if index[i] > 0:
+            test_case.assertTrue(np.array_equal(ret, np.ones((0, 0,), dtype=np.float32) * index[i]))
+        else:
+            test_case.assertTrue(np.array_equal(ret, np.zeros((0, 0,), dtype=np.float32)))