未验证 提交 65534c47 编写于 作者: A Abhinav Arora 提交者: GitHub

Fluid channels should match the semantics of Go Channels (#9265)

* Fluid Channel should match Go Channel in Semantics

* Fix Python channel_send

* Address code rveiew feedback

* Fix open_files_op.cc

* Add description to Channel Asserts
上级 ab5a3560
...@@ -34,7 +34,7 @@ class Channel { ...@@ -34,7 +34,7 @@ class Channel {
public: public:
virtual bool CanSend() = 0; virtual bool CanSend() = 0;
virtual bool CanReceive() = 0; virtual bool CanReceive() = 0;
virtual bool Send(T*) = 0; virtual void Send(T*) = 0;
virtual bool Receive(T*) = 0; virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0; virtual size_t Cap() = 0;
virtual void Lock() = 0; virtual void Lock() = 0;
...@@ -84,69 +84,81 @@ class ChannelHolder { ...@@ -84,69 +84,81 @@ class ChannelHolder {
} }
template <typename T> template <typename T>
bool Send(T* data) { void Send(T* data) {
if (!IsInitialized()) return false; PADDLE_ENFORCE_EQ(IsInitialized(), true,
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T))); "The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
// Static cast should be safe because we have ensured that types are same // Static cast should be safe because we have ensured that types are same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr()); Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Send(data) : false; PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
channel->Send(data);
} }
template <typename T> template <typename T>
bool Receive(T* data) { bool Receive(T* data) {
if (!IsInitialized()) return false; PADDLE_ENFORCE_EQ(IsInitialized(), true,
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T))); "The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr()); Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Receive(data) : false; PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
return channel->Receive(data);
} }
bool IsClosed() { bool IsClosed() {
if (IsInitialized()) { PADDLE_ENFORCE_EQ(IsInitialized(), true,
return holder_->IsClosed(); "The Channel hasn't been initialized");
} return holder_->IsClosed();
return false;
} }
bool CanSend() { bool CanSend() {
if (IsInitialized()) { PADDLE_ENFORCE_EQ(IsInitialized(), true,
return holder_->CanSend(); "The Channel hasn't been initialized");
} return holder_->CanSend();
return false;
} }
bool CanReceive() { bool CanReceive() {
if (IsInitialized()) { PADDLE_ENFORCE_EQ(IsInitialized(), true,
return holder_->CanReceive(); "The Channel hasn't been initialized");
} return holder_->CanReceive();
return false;
} }
void close() { void close() {
if (IsInitialized()) holder_->Close(); PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Close();
} }
size_t Cap() { size_t Cap() {
if (IsInitialized()) return holder_->Cap(); PADDLE_ENFORCE_EQ(IsInitialized(), true,
return -1; "The Channel hasn't been initialized");
return holder_->Cap();
} }
void Lock() { void Lock() {
if (IsInitialized()) holder_->Lock(); PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Lock();
} }
void Unlock() { void Unlock() {
if (IsInitialized()) holder_->Unlock(); PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Unlock();
} }
template <typename T> template <typename T>
void AddToSendQ(const void* referrer, T* data, void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond, std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) { std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) { PADDLE_ENFORCE_EQ(IsInitialized(), true,
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr()); "The Channel hasn't been initialized");
if (channel != nullptr) { Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
channel->AddToSendQ(referrer, data, cond, cb); if (channel != nullptr) {
} channel->AddToSendQ(referrer, data, cond, cb);
} }
} }
...@@ -154,26 +166,31 @@ class ChannelHolder { ...@@ -154,26 +166,31 @@ class ChannelHolder {
void AddToReceiveQ(const void* referrer, T* data, void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond, std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) { std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) { PADDLE_ENFORCE_EQ(IsInitialized(), true,
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr()); "The Channel hasn't been initialized");
if (channel != nullptr) { Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
channel->AddToReceiveQ(referrer, data, cond, cb); if (channel != nullptr) {
} channel->AddToReceiveQ(referrer, data, cond, cb);
} }
} }
void RemoveFromSendQ(const void* referrer) { void RemoveFromSendQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromSendQ(referrer); PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromSendQ(referrer);
} }
void RemoveFromReceiveQ(const void* referrer) { void RemoveFromReceiveQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer); PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromReceiveQ(referrer);
} }
inline bool IsInitialized() const { return holder_ != nullptr; } inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() { inline const std::type_index Type() {
PADDLE_ENFORCE_EQ(IsInitialized(), true); PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Type(); return holder_->Type();
} }
......
...@@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel<T> { ...@@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel<T> {
public: public:
virtual bool CanSend(); virtual bool CanSend();
virtual bool CanReceive(); virtual bool CanReceive();
virtual bool Send(T *); virtual void Send(T *);
virtual bool Receive(T *); virtual bool Receive(T *);
virtual size_t Cap() { return cap_; } virtual size_t Cap() { return cap_; }
virtual void Lock(); virtual void Lock();
...@@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel<T> { ...@@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel<T> {
} }
}; };
bool send_return(bool value) { void send_return() {
send_ctr--; send_ctr--;
destructor_cond_.notify_all(); destructor_cond_.notify_all();
return value;
} }
bool recv_return(bool value) { bool recv_return(bool value) {
...@@ -118,15 +117,15 @@ bool ChannelImpl<T>::CanReceive() { ...@@ -118,15 +117,15 @@ bool ChannelImpl<T>::CanReceive() {
} }
template <typename T> template <typename T>
bool ChannelImpl<T>::Send(T *item) { void ChannelImpl<T>::Send(T *item) {
send_ctr++; send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_}; std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed, do nothing // If channel is closed, throw exception
if (closed_) { if (closed_) {
lock.unlock(); lock.unlock();
// TODO(abhinavarora) Should panic on closed channel send_return();
return send_return(false); PADDLE_THROW("Cannot send on closed channel");
} }
// If there is a receiver, directly pass the value we want // If there is a receiver, directly pass the value we want
...@@ -143,7 +142,7 @@ bool ChannelImpl<T>::Send(T *item) { ...@@ -143,7 +142,7 @@ bool ChannelImpl<T>::Send(T *item) {
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND); if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send) if (do_send)
*(m->data) = std::move(*item); *(m->data) = std::move(*item);
else else {
// We cannot do the data transfer because // We cannot do the data transfer because
// this QueueMessage was added by Select // this QueueMessage was added by Select
// and some other case was executed. // and some other case was executed.
...@@ -151,12 +150,17 @@ bool ChannelImpl<T>::Send(T *item) { ...@@ -151,12 +150,17 @@ bool ChannelImpl<T>::Send(T *item) {
// We do not care about notifying other // We do not care about notifying other
// because they would have been notified // because they would have been notified
// by the executed select case. // by the executed select case.
return send_return(Send(item)); lock.unlock();
Send(item);
send_return();
return;
}
// Wake up the blocked process and unlock // Wake up the blocked process and unlock
m->Notify(); m->Notify();
lock.unlock(); lock.unlock();
return send_return(true); send_return();
return;
} }
// Unbuffered channel will always bypass this // Unbuffered channel will always bypass this
...@@ -167,7 +171,8 @@ bool ChannelImpl<T>::Send(T *item) { ...@@ -167,7 +171,8 @@ bool ChannelImpl<T>::Send(T *item) {
buf_.push_back(std::move(*item)); buf_.push_back(std::move(*item));
// Release lock and return true // Release lock and return true
lock.unlock(); lock.unlock();
return send_return(true); send_return();
return;
} }
// Block on channel, because some receiver will complete // Block on channel, because some receiver will complete
...@@ -175,8 +180,12 @@ bool ChannelImpl<T>::Send(T *item) { ...@@ -175,8 +180,12 @@ bool ChannelImpl<T>::Send(T *item) {
auto m = std::make_shared<QueueMessage>(item); auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m); sendq.push_back(m);
m->Wait(lock); m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel if (m->chan_closed) {
return send_return(!m->chan_closed); lock.unlock();
send_return();
PADDLE_THROW("Cannot send on closed channel");
}
send_return();
} }
template <typename T> template <typename T>
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <chrono> #include <chrono>
#include <thread> #include <thread>
#include "gtest/gtest.h" #include "gtest/gtest.h"
using paddle::framework::Channel; using paddle::framework::Channel;
...@@ -41,7 +40,7 @@ void RecevingOrderEqualToSendingOrder(Channel<int> *ch) { ...@@ -41,7 +40,7 @@ void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
unsigned sum_send = 0; unsigned sum_send = 0;
std::thread t([&]() { std::thread t([&]() {
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
EXPECT_EQ(ch->Send(&i), true); ch->Send(&i);
sum_send += i; sum_send += i;
} }
}); });
...@@ -61,7 +60,7 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { ...@@ -61,7 +60,7 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10; const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size); auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) { for (size_t i = 0; i < buffer_size; ++i) {
EXPECT_EQ(ch->Send(&i), true); // should not block ch->Send(&i);
} }
size_t out; size_t out;
...@@ -82,7 +81,7 @@ void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) { ...@@ -82,7 +81,7 @@ void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
const size_t data = 5; const size_t data = 5;
std::thread send_thread{[&]() { std::thread send_thread{[&]() {
size_t i = data; size_t i = data;
EXPECT_EQ(ch->Send(&i), true); // should not block ch->Send(&i); // should not block
}}; }};
std::thread recv_thread{[&]() { std::thread recv_thread{[&]() {
...@@ -94,12 +93,18 @@ void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) { ...@@ -94,12 +93,18 @@ void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
send_thread.join(); send_thread.join();
recv_thread.join(); recv_thread.join();
// After closing send should return false. Receive should // After closing send should panic. Receive should
// also return false as there is no data in queue. // also false as there is no data in queue.
CloseChannel(ch); CloseChannel(ch);
send_thread = std::thread{[&]() { send_thread = std::thread{[&]() {
size_t i = data; size_t i = data;
EXPECT_EQ(ch->Send(&i), false); // should return false bool is_exception = false;
try {
ch->Send(&i);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
EXPECT_EQ(is_exception, true);
}}; }};
recv_thread = std::thread{[&]() { recv_thread = std::thread{[&]() {
size_t i; size_t i;
...@@ -129,7 +134,7 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { ...@@ -129,7 +134,7 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
auto ch = MakeChannel<size_t>(buffer_size); auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) { for (size_t i = 0; i < buffer_size; ++i) {
EXPECT_EQ(ch->Send(&i), true); // sending should not block ch->Send(&i); // sending should not block
} }
size_t out; size_t out;
...@@ -160,9 +165,16 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { ...@@ -160,9 +165,16 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
// Try to write more than buffer size. // Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) { for (size_t i = 0; i < 2 * buffer_size; ++i) {
if (i < buffer_size) if (i < buffer_size)
EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations ch->Send(&i); // should block after 10 iterations
else else {
EXPECT_EQ(ch->Send(&i), false); bool is_exception = false;
try {
ch->Send(&i);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
EXPECT_EQ(is_exception, true);
}
} }
}); });
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
...@@ -231,7 +243,13 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) { ...@@ -231,7 +243,13 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) {
t[i] = std::thread( t[i] = std::thread(
[&](bool *ended, bool *success) { [&](bool *ended, bool *success) {
int data = 10; int data = 10;
*success = ch->Send(&data); bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true; *ended = true;
}, },
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
...@@ -316,8 +334,11 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) { ...@@ -316,8 +334,11 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) {
// Try to send more number of times // Try to send more number of times
// than receivers // than receivers
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
ch->Send(&i); try {
sum_send += i; ch->Send(&i);
sum_send += i;
} catch (paddle::platform::EnforceNotMet e) {
}
} }
}); });
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
...@@ -382,7 +403,13 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) { ...@@ -382,7 +403,13 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) {
t[i] = std::thread( t[i] = std::thread(
[&](bool *ended, bool *success) { [&](bool *ended, bool *success) {
int data = 10; int data = 10;
*success = ch->Send(&data); bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true; *ended = true;
}, },
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
...@@ -508,7 +535,7 @@ void ChannelHolderSendReceive(ChannelHolder *ch) { ...@@ -508,7 +535,7 @@ void ChannelHolderSendReceive(ChannelHolder *ch) {
unsigned sum_send = 0; unsigned sum_send = 0;
std::thread t([&]() { std::thread t([&]() {
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
EXPECT_EQ(ch->Send(&i), true); ch->Send(&i);
sum_send += i; sum_send += i;
} }
}); });
...@@ -541,8 +568,22 @@ TEST(ChannelHolder, ChannelUninitializedTest) { ...@@ -541,8 +568,22 @@ TEST(ChannelHolder, ChannelUninitializedTest) {
ChannelHolder *ch = new ChannelHolder(); ChannelHolder *ch = new ChannelHolder();
EXPECT_EQ(ch->IsInitialized(), false); EXPECT_EQ(ch->IsInitialized(), false);
int i = 10; int i = 10;
EXPECT_EQ(ch->Send(&i), false); bool send_exception = false;
EXPECT_EQ(ch->Receive(&i), false); try {
ch->Send(&i);
} catch (paddle::platform::EnforceNotMet e) {
send_exception = true;
}
EXPECT_EQ(send_exception, true);
bool recv_exception = false;
try {
ch->Receive(&i);
} catch (paddle::platform::EnforceNotMet e) {
recv_exception = true;
}
EXPECT_EQ(recv_exception, true);
bool is_exception = false; bool is_exception = false;
try { try {
ch->Type(); ch->Type();
...@@ -669,7 +710,13 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { ...@@ -669,7 +710,13 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
t[i] = std::thread( t[i] = std::thread(
[&](bool *ended, bool *success) { [&](bool *ended, bool *success) {
int data = 10; int data = 10;
*success = ch->Send(&data); bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true; *ended = true;
}, },
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
...@@ -760,7 +807,13 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) { ...@@ -760,7 +807,13 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
t[i] = std::thread( t[i] = std::thread(
[&](bool *ended, bool *success) { [&](bool *ended, bool *success) {
int data = 10; int data = 10;
*success = ch->Send(&data); bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true; *ended = true;
}, },
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
......
...@@ -23,21 +23,10 @@ limitations under the License. */ ...@@ -23,21 +23,10 @@ limitations under the License. */
static constexpr char Channel[] = "Channel"; static constexpr char Channel[] = "Channel";
static constexpr char X[] = "X"; static constexpr char X[] = "X";
static constexpr char Status[] = "Status";
static constexpr char copy[] = "copy";
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void SetSendStatus(const platform::Place &dev_place,
framework::Variable &status_var, bool status) {
auto cpu = platform::CPUPlace();
auto status_tensor =
status_var.GetMutable<framework::LoDTensor>()->mutable_data<bool>({1},
cpu);
status_tensor[0] = status;
}
class ChannelSendOp : public framework::OperatorBase { class ChannelSendOp : public framework::OperatorBase {
public: public:
ChannelSendOp(const std::string &type, ChannelSendOp(const std::string &type,
...@@ -51,9 +40,6 @@ class ChannelSendOp : public framework::OperatorBase { ...@@ -51,9 +40,6 @@ class ChannelSendOp : public framework::OperatorBase {
"Input(Channel) of ChannelSendOp should not be null."); "Input(Channel) of ChannelSendOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(X), PADDLE_ENFORCE(ctx->HasInput(X),
"Input(X) of ChannelSendOp should not be null."); "Input(X) of ChannelSendOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(Status),
"Output(Status) of ChannelSendOp should not be null.");
ctx->SetOutputDim("Status", {1});
} }
private: private:
...@@ -65,10 +51,7 @@ class ChannelSendOp : public framework::OperatorBase { ...@@ -65,10 +51,7 @@ class ChannelSendOp : public framework::OperatorBase {
auto input_var = scope.FindVar(Input(X)); auto input_var = scope.FindVar(Input(X));
// Send the input data through the channel. // Send the input data through the channel.
bool ok = concurrency::ChannelSend(ch, input_var); concurrency::ChannelSend(ch, input_var);
// Set the status output of the `ChannelSend` call.
SetSendStatus(dev_place, *scope.FindVar(Output(Status)), ok);
} }
}; };
...@@ -82,12 +65,6 @@ class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -82,12 +65,6 @@ class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable(); .AsDuplicable();
AddInput(X, "(Variable) The value which gets sent by the channel.") AddInput(X, "(Variable) The value which gets sent by the channel.")
.AsDuplicable(); .AsDuplicable();
AddOutput(Status,
"(Tensor) An LoD Tensor that returns a boolean status of the"
"result of the send operation.")
.AsDuplicable();
AddAttr<bool>(copy, "(bool, default false) Should copy before send")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
} }
......
...@@ -17,20 +17,20 @@ limitations under the License. */ ...@@ -17,20 +17,20 @@ limitations under the License. */
namespace poc = paddle::operators::concurrency; namespace poc = paddle::operators::concurrency;
bool poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) { void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
auto type = framework::ToVarType(var->Type()); auto type = framework::ToVarType(var->Type());
if (type == framework::proto::VarType_Type_LOD_TENSOR) if (type == framework::proto::VarType_Type_LOD_TENSOR)
return ch->Send(var->GetMutable<framework::LoDTensor>()); ch->Send(var->GetMutable<framework::LoDTensor>());
else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
return ch->Send(var->GetMutable<framework::LoDRankTable>()); ch->Send(var->GetMutable<framework::LoDRankTable>());
else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
return ch->Send(var->GetMutable<framework::LoDTensorArray>()); ch->Send(var->GetMutable<framework::LoDTensorArray>());
else if (type == framework::proto::VarType_Type_SELECTED_ROWS) else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
return ch->Send(var->GetMutable<framework::SelectedRows>()); ch->Send(var->GetMutable<framework::SelectedRows>());
else if (type == framework::proto::VarType_Type_READER) else if (type == framework::proto::VarType_Type_READER)
return ch->Send(var->GetMutable<framework::ReaderHolder>()); ch->Send(var->GetMutable<framework::ReaderHolder>());
else if (type == framework::proto::VarType_Type_CHANNEL) else if (type == framework::proto::VarType_Type_CHANNEL)
return ch->Send(var->GetMutable<framework::ChannelHolder>()); ch->Send(var->GetMutable<framework::ChannelHolder>());
else else
PADDLE_THROW("ChannelSend:Unsupported type"); PADDLE_THROW("ChannelSend:Unsupported type");
} }
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace concurrency { namespace concurrency {
bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var); void ChannelSend(framework::ChannelHolder *ch, framework::Variable *var);
bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var); bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var);
void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer, void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
......
...@@ -166,7 +166,9 @@ void DoubleBufferReader::PrefetchThreadFunc() { ...@@ -166,7 +166,9 @@ void DoubleBufferReader::PrefetchThreadFunc() {
std::swap(gpu_batch, batch.payloads_); std::swap(gpu_batch, batch.payloads_);
} }
if (!buffer_->Send(&batch)) { try {
buffer_->Send(&batch);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The " VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate."; "prefetch thread will terminate.";
break; break;
......
...@@ -146,14 +146,19 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name, ...@@ -146,14 +146,19 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name,
while (reader->HasNext()) { while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) { try {
buffer_->Send(&ins);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '" "thread of file '"
<< file_name << "' will terminate."; << file_name << "' will terminate.";
break; break;
} }
} }
if (!available_thread_idx_->Send(&thread_idx)) {
try {
available_thread_idx_->Send(&thread_idx);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx."; "Fail to send thread_idx.";
} }
......
...@@ -339,11 +339,6 @@ def channel_send(channel, value, is_copy=False): ...@@ -339,11 +339,6 @@ def channel_send(channel, value, is_copy=False):
main_program = helper.main_program main_program = helper.main_program
channel_send_block = main_program.current_block() channel_send_block = main_program.current_block()
status = helper.create_variable(
name=unique_name.generate('status'),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=core.VarDesc.VarType.BOOL)
X = value X = value
if is_copy is True: if is_copy is True:
...@@ -359,15 +354,11 @@ def channel_send(channel, value, is_copy=False): ...@@ -359,15 +354,11 @@ def channel_send(channel, value, is_copy=False):
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X}) type="assign_op", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X X = copied_X
channel_send_op = channel_send_block.append_op( channel_send_block.append_op(
type="channel_send", type="channel_send", inputs={
inputs={
"Channel": channel, "Channel": channel,
"X": X, "X": X,
}, })
outputs={"Status": status})
return status
def channel_recv(channel, return_value): def channel_recv(channel, return_value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册