未验证 提交 65534c47 编写于 作者: A Abhinav Arora 提交者: GitHub

Fluid channels should match the semantics of Go Channels (#9265)

* Fluid Channel should match Go Channel in Semantics

* Fix Python channel_send

* Address code rveiew feedback

* Fix open_files_op.cc

* Add description to Channel Asserts
上级 ab5a3560
......@@ -34,7 +34,7 @@ class Channel {
public:
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual bool Send(T*) = 0;
virtual void Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Lock() = 0;
......@@ -84,69 +84,81 @@ class ChannelHolder {
}
template <typename T>
bool Send(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
void Send(T* data) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
// Static cast should be safe because we have ensured that types are same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Send(data) : false;
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
channel->Send(data);
}
template <typename T>
bool Receive(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Receive(data) : false;
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
return channel->Receive(data);
}
bool IsClosed() {
if (IsInitialized()) {
return holder_->IsClosed();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->IsClosed();
}
bool CanSend() {
if (IsInitialized()) {
return holder_->CanSend();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->CanSend();
}
bool CanReceive() {
if (IsInitialized()) {
return holder_->CanReceive();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->CanReceive();
}
void close() {
if (IsInitialized()) holder_->Close();
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Close();
}
size_t Cap() {
if (IsInitialized()) return holder_->Cap();
return -1;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Cap();
}
void Lock() {
if (IsInitialized()) holder_->Lock();
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Lock();
}
void Unlock() {
if (IsInitialized()) holder_->Unlock();
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Unlock();
}
template <typename T>
void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
}
......@@ -154,26 +166,31 @@ class ChannelHolder {
void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
}
void RemoveFromSendQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromSendQ(referrer);
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromSendQ(referrer);
}
void RemoveFromReceiveQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer);
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromReceiveQ(referrer);
}
inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() {
PADDLE_ENFORCE_EQ(IsInitialized(), true);
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Type();
}
......
......@@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel<T> {
public:
virtual bool CanSend();
virtual bool CanReceive();
virtual bool Send(T *);
virtual void Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
......@@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel<T> {
}
};
bool send_return(bool value) {
void send_return() {
send_ctr--;
destructor_cond_.notify_all();
return value;
}
bool recv_return(bool value) {
......@@ -118,15 +117,15 @@ bool ChannelImpl<T>::CanReceive() {
}
template <typename T>
bool ChannelImpl<T>::Send(T *item) {
void ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed, do nothing
// If channel is closed, throw exception
if (closed_) {
lock.unlock();
// TODO(abhinavarora) Should panic on closed channel
return send_return(false);
send_return();
PADDLE_THROW("Cannot send on closed channel");
}
// If there is a receiver, directly pass the value we want
......@@ -143,7 +142,7 @@ bool ChannelImpl<T>::Send(T *item) {
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
*(m->data) = std::move(*item);
else
else {
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
......@@ -151,12 +150,17 @@ bool ChannelImpl<T>::Send(T *item) {
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return send_return(Send(item));
lock.unlock();
Send(item);
send_return();
return;
}
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return send_return(true);
send_return();
return;
}
// Unbuffered channel will always bypass this
......@@ -167,7 +171,8 @@ bool ChannelImpl<T>::Send(T *item) {
buf_.push_back(std::move(*item));
// Release lock and return true
lock.unlock();
return send_return(true);
send_return();
return;
}
// Block on channel, because some receiver will complete
......@@ -175,8 +180,12 @@ bool ChannelImpl<T>::Send(T *item) {
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel
return send_return(!m->chan_closed);
if (m->chan_closed) {
lock.unlock();
send_return();
PADDLE_THROW("Cannot send on closed channel");
}
send_return();
}
template <typename T>
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <chrono>
#include <thread>
#include "gtest/gtest.h"
using paddle::framework::Channel;
......@@ -41,7 +40,7 @@ void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
unsigned sum_send = 0;
std::thread t([&]() {
for (int i = 0; i < 5; i++) {
EXPECT_EQ(ch->Send(&i), true);
ch->Send(&i);
sum_send += i;
}
});
......@@ -61,7 +60,7 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) {
EXPECT_EQ(ch->Send(&i), true); // should not block
ch->Send(&i);
}
size_t out;
......@@ -82,7 +81,7 @@ void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
const size_t data = 5;
std::thread send_thread{[&]() {
size_t i = data;
EXPECT_EQ(ch->Send(&i), true); // should not block
ch->Send(&i); // should not block
}};
std::thread recv_thread{[&]() {
......@@ -94,12 +93,18 @@ void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
send_thread.join();
recv_thread.join();
// After closing send should return false. Receive should
// also return false as there is no data in queue.
// After closing send should panic. Receive should
// also false as there is no data in queue.
CloseChannel(ch);
send_thread = std::thread{[&]() {
size_t i = data;
EXPECT_EQ(ch->Send(&i), false); // should return false
bool is_exception = false;
try {
ch->Send(&i);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
EXPECT_EQ(is_exception, true);
}};
recv_thread = std::thread{[&]() {
size_t i;
......@@ -129,7 +134,7 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) {
EXPECT_EQ(ch->Send(&i), true); // sending should not block
ch->Send(&i); // sending should not block
}
size_t out;
......@@ -160,9 +165,16 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
// Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) {
if (i < buffer_size)
EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations
else
EXPECT_EQ(ch->Send(&i), false);
ch->Send(&i); // should block after 10 iterations
else {
bool is_exception = false;
try {
ch->Send(&i);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
EXPECT_EQ(is_exception, true);
}
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
......@@ -231,7 +243,13 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) {
t[i] = std::thread(
[&](bool *ended, bool *success) {
int data = 10;
*success = ch->Send(&data);
bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true;
},
&thread_ended[i], &send_success[i]);
......@@ -316,8 +334,11 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) {
// Try to send more number of times
// than receivers
for (int i = 0; i < 4; i++) {
ch->Send(&i);
sum_send += i;
try {
ch->Send(&i);
sum_send += i;
} catch (paddle::platform::EnforceNotMet e) {
}
}
});
for (int i = 0; i < 3; i++) {
......@@ -382,7 +403,13 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) {
t[i] = std::thread(
[&](bool *ended, bool *success) {
int data = 10;
*success = ch->Send(&data);
bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true;
},
&thread_ended[i], &send_success[i]);
......@@ -508,7 +535,7 @@ void ChannelHolderSendReceive(ChannelHolder *ch) {
unsigned sum_send = 0;
std::thread t([&]() {
for (int i = 0; i < 5; i++) {
EXPECT_EQ(ch->Send(&i), true);
ch->Send(&i);
sum_send += i;
}
});
......@@ -541,8 +568,22 @@ TEST(ChannelHolder, ChannelUninitializedTest) {
ChannelHolder *ch = new ChannelHolder();
EXPECT_EQ(ch->IsInitialized(), false);
int i = 10;
EXPECT_EQ(ch->Send(&i), false);
EXPECT_EQ(ch->Receive(&i), false);
bool send_exception = false;
try {
ch->Send(&i);
} catch (paddle::platform::EnforceNotMet e) {
send_exception = true;
}
EXPECT_EQ(send_exception, true);
bool recv_exception = false;
try {
ch->Receive(&i);
} catch (paddle::platform::EnforceNotMet e) {
recv_exception = true;
}
EXPECT_EQ(recv_exception, true);
bool is_exception = false;
try {
ch->Type();
......@@ -669,7 +710,13 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
t[i] = std::thread(
[&](bool *ended, bool *success) {
int data = 10;
*success = ch->Send(&data);
bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true;
},
&thread_ended[i], &send_success[i]);
......@@ -760,7 +807,13 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
t[i] = std::thread(
[&](bool *ended, bool *success) {
int data = 10;
*success = ch->Send(&data);
bool is_exception = false;
try {
ch->Send(&data);
} catch (paddle::platform::EnforceNotMet e) {
is_exception = true;
}
*success = !is_exception;
*ended = true;
},
&thread_ended[i], &send_success[i]);
......
......@@ -23,21 +23,10 @@ limitations under the License. */
static constexpr char Channel[] = "Channel";
static constexpr char X[] = "X";
static constexpr char Status[] = "Status";
static constexpr char copy[] = "copy";
namespace paddle {
namespace operators {
void SetSendStatus(const platform::Place &dev_place,
framework::Variable &status_var, bool status) {
auto cpu = platform::CPUPlace();
auto status_tensor =
status_var.GetMutable<framework::LoDTensor>()->mutable_data<bool>({1},
cpu);
status_tensor[0] = status;
}
class ChannelSendOp : public framework::OperatorBase {
public:
ChannelSendOp(const std::string &type,
......@@ -51,9 +40,6 @@ class ChannelSendOp : public framework::OperatorBase {
"Input(Channel) of ChannelSendOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(X),
"Input(X) of ChannelSendOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(Status),
"Output(Status) of ChannelSendOp should not be null.");
ctx->SetOutputDim("Status", {1});
}
private:
......@@ -65,10 +51,7 @@ class ChannelSendOp : public framework::OperatorBase {
auto input_var = scope.FindVar(Input(X));
// Send the input data through the channel.
bool ok = concurrency::ChannelSend(ch, input_var);
// Set the status output of the `ChannelSend` call.
SetSendStatus(dev_place, *scope.FindVar(Output(Status)), ok);
concurrency::ChannelSend(ch, input_var);
}
};
......@@ -82,12 +65,6 @@ class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddInput(X, "(Variable) The value which gets sent by the channel.")
.AsDuplicable();
AddOutput(Status,
"(Tensor) An LoD Tensor that returns a boolean status of the"
"result of the send operation.")
.AsDuplicable();
AddAttr<bool>(copy, "(bool, default false) Should copy before send")
.SetDefault(false);
AddComment(R"DOC(
)DOC");
}
......
......@@ -17,20 +17,20 @@ limitations under the License. */
namespace poc = paddle::operators::concurrency;
bool poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
auto type = framework::ToVarType(var->Type());
if (type == framework::proto::VarType_Type_LOD_TENSOR)
return ch->Send(var->GetMutable<framework::LoDTensor>());
ch->Send(var->GetMutable<framework::LoDTensor>());
else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
return ch->Send(var->GetMutable<framework::LoDRankTable>());
ch->Send(var->GetMutable<framework::LoDRankTable>());
else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
return ch->Send(var->GetMutable<framework::LoDTensorArray>());
ch->Send(var->GetMutable<framework::LoDTensorArray>());
else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
return ch->Send(var->GetMutable<framework::SelectedRows>());
ch->Send(var->GetMutable<framework::SelectedRows>());
else if (type == framework::proto::VarType_Type_READER)
return ch->Send(var->GetMutable<framework::ReaderHolder>());
ch->Send(var->GetMutable<framework::ReaderHolder>());
else if (type == framework::proto::VarType_Type_CHANNEL)
return ch->Send(var->GetMutable<framework::ChannelHolder>());
ch->Send(var->GetMutable<framework::ChannelHolder>());
else
PADDLE_THROW("ChannelSend:Unsupported type");
}
......
......@@ -21,7 +21,7 @@ namespace paddle {
namespace operators {
namespace concurrency {
bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var);
void ChannelSend(framework::ChannelHolder *ch, framework::Variable *var);
bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var);
void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
......
......@@ -166,7 +166,9 @@ void DoubleBufferReader::PrefetchThreadFunc() {
std::swap(gpu_batch, batch.payloads_);
}
if (!buffer_->Send(&batch)) {
try {
buffer_->Send(&batch);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
......
......@@ -146,14 +146,19 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name,
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
try {
buffer_->Send(&ins);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
if (!available_thread_idx_->Send(&thread_idx)) {
try {
available_thread_idx_->Send(&thread_idx);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
......
......@@ -339,11 +339,6 @@ def channel_send(channel, value, is_copy=False):
main_program = helper.main_program
channel_send_block = main_program.current_block()
status = helper.create_variable(
name=unique_name.generate('status'),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=core.VarDesc.VarType.BOOL)
X = value
if is_copy is True:
......@@ -359,15 +354,11 @@ def channel_send(channel, value, is_copy=False):
type="assign_op", inputs={"X": value}, outputs={"Out": copied_X})
X = copied_X
channel_send_op = channel_send_block.append_op(
type="channel_send",
inputs={
channel_send_block.append_op(
type="channel_send", inputs={
"Channel": channel,
"X": X,
},
outputs={"Status": status})
return status
})
def channel_recv(channel, return_value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册