diff --git a/paddle/fluid/operators/select_op.cc b/paddle/fluid/operators/select_op.cc index c0bf0ff927481bc4da9cd6c4bb9b0c4a6841c891..876d8acf0d880a7ef806514014d297f98e04c53d 100644 --- a/paddle/fluid/operators/select_op.cc +++ b/paddle/fluid/operators/select_op.cc @@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include -#include +#include // NOLINT #include #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/executor.h" @@ -22,6 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/concurrency/channel_util.h" +#include + namespace paddle { namespace operators { @@ -254,8 +255,8 @@ class SelectOp : public framework::OperatorBase { auto selectCond = std::make_shared(); std::recursive_mutex callbackMutex; - pushThreadOnChannelQueues(scope, cases, selectCond, caseToExecute, - completed, callbackMutex); + pushThreadOnChannelQueues(scope, cases, selectCond, &caseToExecute, + &completed, &callbackMutex); // TODO(thuan): Atomically unlock all channels and sleep current thread unlockChannels(channels); @@ -302,8 +303,8 @@ class SelectOp : public framework::OperatorBase { const framework::Scope *scope, std::vector> *cases, std::shared_ptr rCond, - std::atomic &caseToExecute, std::atomic &completed, - std::recursive_mutex &callbackMutex) const { + std::atomic *caseToExecute, std::atomic *completed, + std::recursive_mutex *callbackMutex) const { std::vector>::iterator it = cases->begin(); while (it != cases->end()) { std::shared_ptr c = *it; @@ -315,17 +316,17 @@ class SelectOp : public framework::OperatorBase { std::function cb = [&caseToExecute, &completed, &callbackMutex, c](framework::ChannelAction channelAction) { - std::lock_guard lock{callbackMutex}; + std::lock_guard lock{*callbackMutex}; bool canProcess = false; - if (!completed) { + if (!(*completed)) { // If the channel wasn't closed, we set the caseToExecute index // as this current case if (channelAction != framework::ChannelAction::CLOSE) { - caseToExecute = c->caseIndex; + *caseToExecute = c->caseIndex; } // This will allow our conditional variable to break out of wait - completed = true; + *completed = true; canProcess = true; }