diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index 8ca1f2aa47b9d9ecc4da3f8d0d917cc07644bf08..be578059388e627778d485365e5a735158603fe5 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -100,8 +100,7 @@ class ChannelHolder { virtual ~Placeholder() {} virtual const std::type_index Type() const = 0; virtual void* Ptr() const = 0; - virtual void Close() const = 0; - std::type_info type_; + virtual void Close() = 0; }; template @@ -116,7 +115,7 @@ class ChannelHolder { if (channel_) channel_->Close(); } - std::unique_ptr*> channel_; + std::unique_ptr> channel_; const std::type_index type_; }; diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index 953fa40fec8c0480726b44760a3a4c7f59c80a85..2c4e622bd789c691c38a8810fe5e09e464a8cf1f 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "gtest/gtest.h" using paddle::framework::Channel; +using paddle::framework::ChannelHolder; using paddle::framework::MakeChannel; using paddle::framework::CloseChannel; using paddle::framework::details::Buffered; @@ -508,3 +509,36 @@ TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) { auto ch = MakeChannel(0); ChannelDestroyUnblockSenders(ch); } + +void ChannelHolderSendReceive(ChannelHolder *ch) { + unsigned sum_send = 0; + std::thread t([&]() { + for (int i = 0; i < 5; i++) { + EXPECT_EQ(ch->Send(&i), true); + sum_send += i; + } + }); + for (int i = 0; i < 5; i++) { + int recv; + EXPECT_EQ(ch->Receive(&recv), true); + EXPECT_EQ(recv, i); + } + + ch->close(); + t.join(); + EXPECT_EQ(sum_send, 10U); +} + +TEST(ChannelHolder, ChannelHolderBufferedSendReceiveTest) { + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(10); + ChannelHolderSendReceive(ch); + delete ch; +} + +TEST(ChannelHolder, ChannelHolderUnBufferedSendReceiveTest) { + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(0); + ChannelHolderSendReceive(ch); + delete ch; +}