From 65534c47625239ce68b5e5c02ae72c3bb1532214 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Mon, 26 Mar 2018 19:11:54 -0700 Subject: [PATCH] Fluid channels should match the semantics of Go Channels (#9265) * Fluid Channel should match Go Channel in Semantics * Fix Python channel_send * Address code rveiew feedback * Fix open_files_op.cc * Add description to Channel Asserts --- paddle/fluid/framework/channel.h | 93 +++++++++++-------- paddle/fluid/framework/channel_impl.h | 35 ++++--- paddle/fluid/framework/channel_test.cc | 93 +++++++++++++++---- paddle/fluid/operators/channel_send_op.cc | 25 +---- .../operators/concurrency/channel_util.cc | 14 +-- .../operators/concurrency/channel_util.h | 2 +- .../reader/create_double_buffer_reader_op.cc | 4 +- .../fluid/operators/reader/open_files_op.cc | 9 +- python/paddle/fluid/concurrency.py | 15 +-- 9 files changed, 172 insertions(+), 118 deletions(-) diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index adfaba26ac..019bea600f 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -34,7 +34,7 @@ class Channel { public: virtual bool CanSend() = 0; virtual bool CanReceive() = 0; - virtual bool Send(T*) = 0; + virtual void Send(T*) = 0; virtual bool Receive(T*) = 0; virtual size_t Cap() = 0; virtual void Lock() = 0; @@ -84,69 +84,81 @@ class ChannelHolder { } template - bool Send(T* data) { - if (!IsInitialized()) return false; - PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T))); + void Send(T* data) { + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + PADDLE_ENFORCE_EQ( + holder_->Type(), std::type_index(typeid(T)), + "Channel type is not same as the type of the data being sent"); // Static cast should be safe because we have ensured that types are same Channel* channel = static_cast*>(holder_->Ptr()); - return channel != nullptr ? channel->Send(data) : false; + PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null."); + channel->Send(data); } template bool Receive(T* data) { - if (!IsInitialized()) return false; - PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T))); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + PADDLE_ENFORCE_EQ( + holder_->Type(), std::type_index(typeid(T)), + "Channel type is not same as the type of the data being sent"); Channel* channel = static_cast*>(holder_->Ptr()); - return channel != nullptr ? channel->Receive(data) : false; + PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null."); + return channel->Receive(data); } bool IsClosed() { - if (IsInitialized()) { - return holder_->IsClosed(); - } - return false; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->IsClosed(); } bool CanSend() { - if (IsInitialized()) { - return holder_->CanSend(); - } - return false; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->CanSend(); } bool CanReceive() { - if (IsInitialized()) { - return holder_->CanReceive(); - } - return false; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->CanReceive(); } void close() { - if (IsInitialized()) holder_->Close(); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->Close(); } size_t Cap() { - if (IsInitialized()) return holder_->Cap(); - return -1; + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + return holder_->Cap(); } void Lock() { - if (IsInitialized()) holder_->Lock(); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->Lock(); } void Unlock() { - if (IsInitialized()) holder_->Unlock(); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->Unlock(); } template void AddToSendQ(const void* referrer, T* data, std::shared_ptr cond, std::function cb) { - if (IsInitialized()) { - Channel* channel = static_cast*>(holder_->Ptr()); - if (channel != nullptr) { - channel->AddToSendQ(referrer, data, cond, cb); - } + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + Channel* channel = static_cast*>(holder_->Ptr()); + if (channel != nullptr) { + channel->AddToSendQ(referrer, data, cond, cb); } } @@ -154,26 +166,31 @@ class ChannelHolder { void AddToReceiveQ(const void* referrer, T* data, std::shared_ptr cond, std::function cb) { - if (IsInitialized()) { - Channel* channel = static_cast*>(holder_->Ptr()); - if (channel != nullptr) { - channel->AddToReceiveQ(referrer, data, cond, cb); - } + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + Channel* channel = static_cast*>(holder_->Ptr()); + if (channel != nullptr) { + channel->AddToReceiveQ(referrer, data, cond, cb); } } void RemoveFromSendQ(const void* referrer) { - if (IsInitialized()) holder_->RemoveFromSendQ(referrer); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->RemoveFromSendQ(referrer); } void RemoveFromReceiveQ(const void* referrer) { - if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); + holder_->RemoveFromReceiveQ(referrer); } inline bool IsInitialized() const { return holder_ != nullptr; } inline const std::type_index Type() { - PADDLE_ENFORCE_EQ(IsInitialized(), true); + PADDLE_ENFORCE_EQ(IsInitialized(), true, + "The Channel hasn't been initialized"); return holder_->Type(); } diff --git a/paddle/fluid/framework/channel_impl.h b/paddle/fluid/framework/channel_impl.h index 457abbf373..378a0bab1c 100644 --- a/paddle/fluid/framework/channel_impl.h +++ b/paddle/fluid/framework/channel_impl.h @@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel { public: virtual bool CanSend(); virtual bool CanReceive(); - virtual bool Send(T *); + virtual void Send(T *); virtual bool Receive(T *); virtual size_t Cap() { return cap_; } virtual void Lock(); @@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel { } }; - bool send_return(bool value) { + void send_return() { send_ctr--; destructor_cond_.notify_all(); - return value; } bool recv_return(bool value) { @@ -118,15 +117,15 @@ bool ChannelImpl::CanReceive() { } template -bool ChannelImpl::Send(T *item) { +void ChannelImpl::Send(T *item) { send_ctr++; std::unique_lock lock{mu_}; - // If channel is closed, do nothing + // If channel is closed, throw exception if (closed_) { lock.unlock(); - // TODO(abhinavarora) Should panic on closed channel - return send_return(false); + send_return(); + PADDLE_THROW("Cannot send on closed channel"); } // If there is a receiver, directly pass the value we want @@ -143,7 +142,7 @@ bool ChannelImpl::Send(T *item) { if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND); if (do_send) *(m->data) = std::move(*item); - else + else { // We cannot do the data transfer because // this QueueMessage was added by Select // and some other case was executed. @@ -151,12 +150,17 @@ bool ChannelImpl::Send(T *item) { // We do not care about notifying other // because they would have been notified // by the executed select case. - return send_return(Send(item)); + lock.unlock(); + Send(item); + send_return(); + return; + } // Wake up the blocked process and unlock m->Notify(); lock.unlock(); - return send_return(true); + send_return(); + return; } // Unbuffered channel will always bypass this @@ -167,7 +171,8 @@ bool ChannelImpl::Send(T *item) { buf_.push_back(std::move(*item)); // Release lock and return true lock.unlock(); - return send_return(true); + send_return(); + return; } // Block on channel, because some receiver will complete @@ -175,8 +180,12 @@ bool ChannelImpl::Send(T *item) { auto m = std::make_shared(item); sendq.push_back(m); m->Wait(lock); - // TODO(abhinavarora) Should panic on closed channel - return send_return(!m->chan_closed); + if (m->chan_closed) { + lock.unlock(); + send_return(); + PADDLE_THROW("Cannot send on closed channel"); + } + send_return(); } template diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index 73be5cdbe2..e2380bb54b 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include #include - #include "gtest/gtest.h" using paddle::framework::Channel; @@ -41,7 +40,7 @@ void RecevingOrderEqualToSendingOrder(Channel *ch) { unsigned sum_send = 0; std::thread t([&]() { for (int i = 0; i < 5; i++) { - EXPECT_EQ(ch->Send(&i), true); + ch->Send(&i); sum_send += i; } }); @@ -61,7 +60,7 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { const size_t buffer_size = 10; auto ch = MakeChannel(buffer_size); for (size_t i = 0; i < buffer_size; ++i) { - EXPECT_EQ(ch->Send(&i), true); // should not block + ch->Send(&i); } size_t out; @@ -82,7 +81,7 @@ void SendReceiveWithACloseChannelShouldPanic(Channel *ch) { const size_t data = 5; std::thread send_thread{[&]() { size_t i = data; - EXPECT_EQ(ch->Send(&i), true); // should not block + ch->Send(&i); // should not block }}; std::thread recv_thread{[&]() { @@ -94,12 +93,18 @@ void SendReceiveWithACloseChannelShouldPanic(Channel *ch) { send_thread.join(); recv_thread.join(); - // After closing send should return false. Receive should - // also return false as there is no data in queue. + // After closing send should panic. Receive should + // also false as there is no data in queue. CloseChannel(ch); send_thread = std::thread{[&]() { size_t i = data; - EXPECT_EQ(ch->Send(&i), false); // should return false + bool is_exception = false; + try { + ch->Send(&i); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); }}; recv_thread = std::thread{[&]() { size_t i; @@ -129,7 +134,7 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { auto ch = MakeChannel(buffer_size); for (size_t i = 0; i < buffer_size; ++i) { - EXPECT_EQ(ch->Send(&i), true); // sending should not block + ch->Send(&i); // sending should not block } size_t out; @@ -160,9 +165,16 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { // Try to write more than buffer size. for (size_t i = 0; i < 2 * buffer_size; ++i) { if (i < buffer_size) - EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations - else - EXPECT_EQ(ch->Send(&i), false); + ch->Send(&i); // should block after 10 iterations + else { + bool is_exception = false; + try { + ch->Send(&i); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + } } }); std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec @@ -231,7 +243,13 @@ void ChannelCloseUnblocksSendersTest(Channel *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); @@ -316,8 +334,11 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) { // Try to send more number of times // than receivers for (int i = 0; i < 4; i++) { - ch->Send(&i); - sum_send += i; + try { + ch->Send(&i); + sum_send += i; + } catch (paddle::platform::EnforceNotMet e) { + } } }); for (int i = 0; i < 3; i++) { @@ -382,7 +403,13 @@ void ChannelDestroyUnblockSenders(Channel *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); @@ -508,7 +535,7 @@ void ChannelHolderSendReceive(ChannelHolder *ch) { unsigned sum_send = 0; std::thread t([&]() { for (int i = 0; i < 5; i++) { - EXPECT_EQ(ch->Send(&i), true); + ch->Send(&i); sum_send += i; } }); @@ -541,8 +568,22 @@ TEST(ChannelHolder, ChannelUninitializedTest) { ChannelHolder *ch = new ChannelHolder(); EXPECT_EQ(ch->IsInitialized(), false); int i = 10; - EXPECT_EQ(ch->Send(&i), false); - EXPECT_EQ(ch->Receive(&i), false); + bool send_exception = false; + try { + ch->Send(&i); + } catch (paddle::platform::EnforceNotMet e) { + send_exception = true; + } + EXPECT_EQ(send_exception, true); + + bool recv_exception = false; + try { + ch->Receive(&i); + } catch (paddle::platform::EnforceNotMet e) { + recv_exception = true; + } + EXPECT_EQ(recv_exception, true); + bool is_exception = false; try { ch->Type(); @@ -669,7 +710,13 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); @@ -760,7 +807,13 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) { t[i] = std::thread( [&](bool *ended, bool *success) { int data = 10; - *success = ch->Send(&data); + bool is_exception = false; + try { + ch->Send(&data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + *success = !is_exception; *ended = true; }, &thread_ended[i], &send_success[i]); diff --git a/paddle/fluid/operators/channel_send_op.cc b/paddle/fluid/operators/channel_send_op.cc index 47cf7d7efc..66d33617ed 100644 --- a/paddle/fluid/operators/channel_send_op.cc +++ b/paddle/fluid/operators/channel_send_op.cc @@ -23,21 +23,10 @@ limitations under the License. */ static constexpr char Channel[] = "Channel"; static constexpr char X[] = "X"; -static constexpr char Status[] = "Status"; -static constexpr char copy[] = "copy"; namespace paddle { namespace operators { -void SetSendStatus(const platform::Place &dev_place, - framework::Variable &status_var, bool status) { - auto cpu = platform::CPUPlace(); - auto status_tensor = - status_var.GetMutable()->mutable_data({1}, - cpu); - status_tensor[0] = status; -} - class ChannelSendOp : public framework::OperatorBase { public: ChannelSendOp(const std::string &type, @@ -51,9 +40,6 @@ class ChannelSendOp : public framework::OperatorBase { "Input(Channel) of ChannelSendOp should not be null."); PADDLE_ENFORCE(ctx->HasInput(X), "Input(X) of ChannelSendOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput(Status), - "Output(Status) of ChannelSendOp should not be null."); - ctx->SetOutputDim("Status", {1}); } private: @@ -65,10 +51,7 @@ class ChannelSendOp : public framework::OperatorBase { auto input_var = scope.FindVar(Input(X)); // Send the input data through the channel. - bool ok = concurrency::ChannelSend(ch, input_var); - - // Set the status output of the `ChannelSend` call. - SetSendStatus(dev_place, *scope.FindVar(Output(Status)), ok); + concurrency::ChannelSend(ch, input_var); } }; @@ -82,12 +65,6 @@ class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker { .AsDuplicable(); AddInput(X, "(Variable) The value which gets sent by the channel.") .AsDuplicable(); - AddOutput(Status, - "(Tensor) An LoD Tensor that returns a boolean status of the" - "result of the send operation.") - .AsDuplicable(); - AddAttr(copy, "(bool, default false) Should copy before send") - .SetDefault(false); AddComment(R"DOC( )DOC"); } diff --git a/paddle/fluid/operators/concurrency/channel_util.cc b/paddle/fluid/operators/concurrency/channel_util.cc index a483af7aff..246c99489c 100644 --- a/paddle/fluid/operators/concurrency/channel_util.cc +++ b/paddle/fluid/operators/concurrency/channel_util.cc @@ -17,20 +17,20 @@ limitations under the License. */ namespace poc = paddle::operators::concurrency; -bool poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { +void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { auto type = framework::ToVarType(var->Type()); if (type == framework::proto::VarType_Type_LOD_TENSOR) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_SELECTED_ROWS) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_READER) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else if (type == framework::proto::VarType_Type_CHANNEL) - return ch->Send(var->GetMutable()); + ch->Send(var->GetMutable()); else PADDLE_THROW("ChannelSend:Unsupported type"); } diff --git a/paddle/fluid/operators/concurrency/channel_util.h b/paddle/fluid/operators/concurrency/channel_util.h index c3674bd981..cd18ca78c6 100644 --- a/paddle/fluid/operators/concurrency/channel_util.h +++ b/paddle/fluid/operators/concurrency/channel_util.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { namespace concurrency { -bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var); +void ChannelSend(framework::ChannelHolder *ch, framework::Variable *var); bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var); void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer, diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 76cdb794cc..141a3eb935 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -166,7 +166,9 @@ void DoubleBufferReader::PrefetchThreadFunc() { std::swap(gpu_batch, batch.payloads_); } - if (!buffer_->Send(&batch)) { + try { + buffer_->Send(&batch); + } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " "prefetch thread will terminate."; break; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 414c76fea0..b6ac7b21d5 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -146,14 +146,19 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name, while (reader->HasNext()) { std::vector ins; reader->ReadNext(&ins); - if (!buffer_->Send(&ins)) { + try { + buffer_->Send(&ins); + } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " "thread of file '" << file_name << "' will terminate."; break; } } - if (!available_thread_idx_->Send(&thread_idx)) { + + try { + available_thread_idx_->Send(&thread_idx); + } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " "Fail to send thread_idx."; } diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py index d65e1a6858..a0f5ef2329 100644 --- a/python/paddle/fluid/concurrency.py +++ b/python/paddle/fluid/concurrency.py @@ -339,11 +339,6 @@ def channel_send(channel, value, is_copy=False): main_program = helper.main_program channel_send_block = main_program.current_block() - status = helper.create_variable( - name=unique_name.generate('status'), - type=core.VarDesc.VarType.LOD_TENSOR, - dtype=core.VarDesc.VarType.BOOL) - X = value if is_copy is True: @@ -359,15 +354,11 @@ def channel_send(channel, value, is_copy=False): type="assign_op", inputs={"X": value}, outputs={"Out": copied_X}) X = copied_X - channel_send_op = channel_send_block.append_op( - type="channel_send", - inputs={ + channel_send_block.append_op( + type="channel_send", inputs={ "Channel": channel, "X": X, - }, - outputs={"Status": status}) - - return status + }) def channel_recv(channel, return_value): -- GitLab