/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/concurrency/channel_util.h" namespace paddle { namespace operators { static constexpr char kX[] = "X"; static constexpr char kCaseToExecute[] = "case_to_execute"; static constexpr char kCases[] = "cases"; static constexpr char kCasesBlock[] = "sub_block"; class SelectOp : public framework::OperatorBase { public: SelectOp(const std::string &type, const framework::VariableNameMap &inputs, const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} private: enum class SelectOpCaseType { DEFAULT = 0, SEND = 1, RECEIVE = 2, }; struct SelectOpCase { int caseIndex; SelectOpCaseType caseType; std::string channelName; std::string varName; SelectOpCase() {} SelectOpCase(int caseIndex, SelectOpCaseType caseType, std::string channelName, std::string varName) : caseIndex(caseIndex), caseType(caseType), channelName(channelName), varName(varName) {} }; void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { std::vector casesConfigs = Attr>(kCases); framework::BlockDesc *casesBlock = Attr(kCasesBlock); framework::Scope &casesBlockScope = scope.NewScope(); std::string caseToExecuteVarName = Input(kCaseToExecute); framework::Variable *caseToExecuteVar = casesBlockScope.FindVar(caseToExecuteVarName); // Construct cases from "conditional_block_op"(s) in the casesBlock std::vector> cases = ParseAndShuffleCases(&casesConfigs); // Get all unique channels involved in select std::set channelsSet; for (auto c : cases) { if (!c->channelName.empty()) { auto channelVar = scope.FindVar(c->channelName); framework::ChannelHolder *ch = channelVar->GetMutable(); if (channelsSet.find(ch) == channelsSet.end()) { channelsSet.insert(ch); } } } // Order all channels by their pointer address std::vector channels(channelsSet.begin(), channelsSet.end()); std::sort(channels.begin(), channels.end()); // Poll all cases int32_t caseToExecute = pollCases(&scope, &cases, channels); // At this point, the case to execute has already been determined, // so we can proceed with executing the cases block framework::LoDTensor *caseToExecuteTensor = caseToExecuteVar->GetMutable(); caseToExecuteTensor->data()[0] = caseToExecute; // Execute the cases block, only one case will be executed since we set the // case_to_execute value to the index of the case we want to execute framework::Executor executor(dev_place); framework::ProgramDesc *program = casesBlock->Program(); executor.Run(*program, &casesBlockScope, casesBlock->ID(), false /*create_local_scope*/); } /** * Goes through all operators in the casesConfigs and processes * "conditional_block" operators. These operators are mapped to our * SelectOpCase objects. We randomize the case orders, and set the * default case (if any exists) as the last case) * @param casesBlock * @return */ std::vector> ParseAndShuffleCases( std::vector *casesConfigs) const { std::vector> cases; std::shared_ptr defaultCase; if (casesConfigs != nullptr) { boost::char_delimiters_separator sep(false, ",", ""); for (std::vector::iterator itr = casesConfigs->begin(); itr < casesConfigs->end(); ++itr) { std::string caseConfig = *itr; boost::tokenizer<> tokens(caseConfig, sep); boost::tokenizer<>::iterator tok_iter = tokens.begin(); PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case index"); std::string caseIndexString = *tok_iter; int caseIndex = std::stoi(caseIndexString); ++tok_iter; PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case type"); std::string caseTypeString = *tok_iter; SelectOpCaseType caseType = (SelectOpCaseType)std::stoi(caseTypeString); std::string caseChannel; std::string caseChannelVar; ++tok_iter; if (caseType != SelectOpCaseType::DEFAULT) { PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case channel"); caseChannel = *tok_iter; ++tok_iter; PADDLE_ENFORCE(tok_iter != tokens.end(), "Cannot get case channel variable"); caseChannelVar = *tok_iter; } auto c = std::make_shared(caseIndex, caseType, caseChannel, caseChannelVar); if (caseType == SelectOpCaseType::DEFAULT) { PADDLE_ENFORCE(defaultCase == nullptr, "Select can only contain one default case."); defaultCase = c; } else { cases.push_back(c); } } } // Randomly sort cases, with default case being last std::random_shuffle(cases.begin(), cases.end()); if (defaultCase != nullptr) { cases.push_back(defaultCase); } return cases; } /** * This method will recursively poll the cases and determines if any case * condition is true. * If none of the cases conditions are true (and there is no default case), * then block * the thread. The thread may be woken up by a channel operation, at which * point we * execute the case. * @param scope * @param cases * @param channels * @return */ int32_t pollCases(const framework::Scope *scope, std::vector> *cases, std::vector channels) const { // Lock all involved channels lockChannels(channels); std::atomic caseToExecute(-1); std::vector>::iterator it = cases->begin(); while (it != cases->end()) { std::shared_ptr c = *it; auto chVar = scope->FindVar(c->channelName); framework::ChannelHolder *ch = chVar->GetMutable(); switch (c->caseType) { case SelectOpCaseType::SEND: PADDLE_ENFORCE(!ch->IsClosed(), "Cannot send to a closed channel"); if (ch->CanSend()) { // We can send to channel directly, send the data to channel // and execute case auto chVar = scope->FindVar(c->varName); concurrency::ChannelSend(ch, chVar); caseToExecute = c->caseIndex; } break; case SelectOpCaseType::RECEIVE: if (ch->CanReceive()) { // We can receive from channel directly, send the data to channel // and execute case auto chVar = scope->FindVar(c->varName); concurrency::ChannelReceive(ch, chVar); caseToExecute = c->caseIndex; } break; case SelectOpCaseType::DEFAULT: caseToExecute = c->caseIndex; break; } if (caseToExecute != -1) { // We found a case to execute, stop looking at other case statements break; } ++it; } if (caseToExecute == -1) { // None of the cases are eligible to execute, enqueue current thread // into all the sending/receiving queue of each involved channel std::atomic completed(false); std::recursive_mutex mutex; std::unique_lock lock{mutex}; // std::condition_variable_any selectCond; auto selectCond = std::make_shared(); std::recursive_mutex callbackMutex; pushThreadOnChannelQueues(scope, cases, selectCond, caseToExecute, completed, callbackMutex); // TODO(thuan): Atomically unlock all channels and sleep current thread unlockChannels(channels); selectCond->wait(lock, [&completed]() { return completed.load(); }); // Select has been woken up by case operation lockChannels(channels); removeThreadOnChannelQueues(scope, cases); if (caseToExecute == -1) { // Recursively poll cases, since we were woken up by a channel close // TODO(thuan): Need to test if this is a valid case unlockChannels(channels); return pollCases(scope, cases, channels); } } // At this point, caseToExecute != -1, and we can proceed with executing // the case block unlockChannels(channels); return caseToExecute; } void lockChannels(std::vector chs) const { std::vector::iterator it = chs.begin(); while (it != chs.end()) { framework::ChannelHolder *ch = *it; ch->Lock(); ++it; } } void unlockChannels(std::vector chs) const { std::vector::reverse_iterator it = chs.rbegin(); while (it != chs.rend()) { framework::ChannelHolder *ch = *it; ch->Unlock(); ++it; } } void pushThreadOnChannelQueues( const framework::Scope *scope, std::vector> *cases, std::shared_ptr rCond, std::atomic &caseToExecute, std::atomic &completed, std::recursive_mutex &callbackMutex) const { std::vector>::iterator it = cases->begin(); while (it != cases->end()) { std::shared_ptr c = *it; auto chVar = scope->FindVar(c->channelName); framework::ChannelHolder *ch = chVar->GetMutable(); std::function cb = [&caseToExecute, &completed, &callbackMutex, c](framework::ChannelAction channelAction) { std::lock_guard lock{callbackMutex}; bool canProcess = false; if (!completed) { // If the channel wasn't closed, we set the caseToExecute index // as this current case if (channelAction != framework::ChannelAction::CLOSE) { caseToExecute = c->caseIndex; } // This will allow our conditional variable to break out of wait completed = true; canProcess = true; } return canProcess; }; switch (c->caseType) { case SelectOpCaseType::SEND: { auto chOutputVar = scope->FindVar(c->varName); concurrency::ChannelAddToSendQ(ch, this, chOutputVar, rCond, cb); break; } case SelectOpCaseType::RECEIVE: { auto chOutputVar = scope->FindVar(c->varName); concurrency::ChannelAddToReceiveQ(ch, this, chOutputVar, rCond, cb); break; } default: break; } ++it; } } void removeThreadOnChannelQueues( const framework::Scope *scope, std::vector> *cases) const { std::vector>::iterator it = cases->begin(); while (it != cases->end()) { std::shared_ptr c = *it; auto chVar = scope->FindVar(c->channelName); framework::ChannelHolder *ch = chVar->GetMutable(); switch (c->caseType) { case SelectOpCaseType::SEND: { ch->RemoveFromSendQ(this); break; } case SelectOpCaseType::RECEIVE: { ch->RemoveFromReceiveQ(this); break; } default: break; } ++it; } } }; class SelectOpMaker : public framework::OpProtoAndCheckerMaker { public: SelectOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput(kX, "A set of variables, which are required by operators inside the " "cases of Select Op") .AsDuplicable(); AddInput(kCaseToExecute, "(Int) The variable the sets the index of the case to execute, " "after evaluating the channels being sent to and received from") .AsDuplicable(); AddAttr>(kCases, "(String vector) Serialized list of" "all cases in the select op. Each" "case is serialized as: " "',,,'" "where type is 0 for default, 1 for" "send, and 2 for receive" "No channel and values are needed for" "default cases."); AddAttr(kCasesBlock, "The cases block inside select_op"); AddComment(R"DOC( )DOC"); } }; // TODO(thuan): Implement Gradient Operator for SELECT_OP } // namespace operators } // namespace paddle REGISTER_OPERATOR(select, paddle::operators::SelectOp, paddle::framework::EmptyGradOpMaker, paddle::operators::SelectOpMaker);