workqueue.cc 5.9 KB
Newer Older
1 2 3 4 5 6 7 8
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "paddle/fluid/framework/new_executor/workqueue.h"

#include <cstdlib>
#include <new>

#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
L
liutiexing 已提交
14
namespace {
15

L
liutiexing 已提交
16 17
// File-local shorthand for the TaskTracker specialization used by both queue
// implementations below (they construct it from an EventsWaiter::EventNotifier
// returned by RegisterEvent). The alias deliberately reuses the template's
// name: an alias's point of declaration follows its type-id, so the
// right-hand side still names the enclosing TaskTracker class template.
using TaskTracker = TaskTracker<EventsWaiter::EventNotifier>;

L
liutiexing 已提交
18
class WorkQueueImpl : public WorkQueue {
19
 public:
L
liutiexing 已提交
20 21
  explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
    if (options_.track_task && options.queue_empty_waiter != nullptr) {
L
liutiexing 已提交
22
      void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
L
liutiexing 已提交
23 24 25 26 27
      TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
      auto notifier = options.queue_empty_waiter->RegisterEvent(
          kQueueEmptyEvent,
          [tracker]() { return tracker->PendingTaskNum() == 0; });
      tracker_ = new (storage) TaskTracker(*notifier.get());
L
liutiexing 已提交
28 29 30 31
    }
    queue_ = new NonblockingThreadPool(options_.num_threads,
                                       options_.allow_spinning);
  }
32

L
liutiexing 已提交
33
  virtual ~WorkQueueImpl() {
L
liutiexing 已提交
34 35 36 37
    if (tracker_ != nullptr) {
      tracker_->~TaskTracker();
      AlignedFree(tracker_);
    }
L
liutiexing 已提交
38 39
    delete queue_;
  }
40 41

  void AddTask(std::function<void()> fn) override {
L
liutiexing 已提交
42 43 44 45 46 47 48 49
    if (tracker_ != nullptr) {
      fn = [
        task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
      ]() mutable {
        task();
      };
    }
    queue_->AddTask(std::move(fn));
50 51
  }

52 53 54 55 56
  void Cancel() override {
    queue_->Cancel();
    queue_->WaitThreadsExit();
  }

L
liutiexing 已提交
57
  size_t NumThreads() const override { return queue_->NumThreads(); }
58 59

 private:
L
liutiexing 已提交
60 61
  NonblockingThreadPool* queue_{nullptr};
  TaskTracker* tracker_{nullptr};
62 63
};

L
liutiexing 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76
class WorkQueueGroupImpl : public WorkQueueGroup {
 public:
  explicit WorkQueueGroupImpl(
      const std::vector<WorkQueueOptions>& queue_options);

  ~WorkQueueGroupImpl();

  void AddTask(size_t queue_idx, std::function<void()> fn) override;

  size_t QueueNumThreads(size_t queue_idx) const override;

  size_t QueueGroupNumThreads() const override;

77 78
  void Cancel() override;

L
liutiexing 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 private:
  std::vector<NonblockingThreadPool*> queues_;
  NonblockingThreadPool* queues_storage_;
  TaskTracker* tracker_;
};

// Builds one NonblockingThreadPool per entry of `queues_options`, each
// placement-constructed into a single malloc'd buffer. The group's shared
// TaskTracker is created from the first entry that has track_task set and a
// non-null queue_empty_waiter.
//
// Because the destructor does not run for a partially-constructed object,
// anything built before a throwing pool/tracker constructor is torn down here
// before the exception is rethrown.
WorkQueueGroupImpl::WorkQueueGroupImpl(
    const std::vector<WorkQueueOptions>& queues_options)
    : WorkQueueGroup(queues_options),
      queues_storage_(nullptr),
      tracker_(nullptr) {
  const size_t num_queues = queues_options_.size();
  queues_.resize(num_queues);
  void* buffer = malloc(sizeof(NonblockingThreadPool) * num_queues);
  // Placement-new into a null buffer is undefined behavior; fail loudly.
  if (buffer == nullptr && num_queues > 0) {
    throw std::bad_alloc();
  }
  queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);

  size_t num_constructed = 0;
  try {
    for (size_t idx = 0; idx < num_queues; ++idx) {
      const auto& options = queues_options_[idx];
      if (options.track_task && tracker_ == nullptr &&
          options.queue_empty_waiter != nullptr) {
        // The condition lambda needs the tracker's address before the
        // tracker is constructed, so allocate raw storage first.
        void* storage =
            AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
        TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
        auto notifier = options.queue_empty_waiter->RegisterEvent(
            kQueueEmptyEvent,
            [tracker]() { return tracker->PendingTaskNum() == 0; });
        tracker_ = new (storage) TaskTracker(*notifier.get());
      }
      queues_[idx] = new (&queues_storage_[idx])
          NonblockingThreadPool(options.num_threads, options.allow_spinning);
      ++num_constructed;
    }
  } catch (...) {
    // Same teardown order as the destructor: pools, tracker, raw buffer.
    for (size_t idx = 0; idx < num_constructed; ++idx) {
      queues_[idx]->~NonblockingThreadPool();
    }
    if (tracker_ != nullptr) {
      tracker_->~TaskTracker();
      AlignedFree(tracker_);
      tracker_ = nullptr;
    }
    free(queues_storage_);
    queues_storage_ = nullptr;
    throw;
  }
}

L
liutiexing 已提交
110 111 112
// Tears the group down in three steps: destroy every in-place pool, then the
// shared tracker (tracked task closures hold CounterGuards pointing at it, so
// it must outlive the pools), and finally release the raw pool buffer.
WorkQueueGroupImpl::~WorkQueueGroupImpl() {
  for (size_t i = 0; i < queues_.size(); ++i) {
    NonblockingThreadPool* pool = queues_[i];
    pool->~NonblockingThreadPool();
  }
  if (TaskTracker* tracker = tracker_) {
    tracker->~TaskTracker();
    AlignedFree(tracker);
  }
  free(queues_storage_);
}
120

L
liutiexing 已提交
121 122 123 124 125 126 127 128 129 130 131
// Enqueues `fn` on the queue at `queue_idx`. When that queue's options request
// task tracking and the group actually has a tracker, `fn` is wrapped so a
// CounterGuard on the tracker lives as long as the task closure.
void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
  assert(queue_idx < queues_.size());
  // tracker_ is only created when a tracking queue also supplied a
  // queue_empty_waiter, so it can be nullptr even for a track_task queue.
  // Guard against building a CounterGuard on a null tracker, matching
  // WorkQueueImpl::AddTask. `.at()` is evaluated first so an out-of-range
  // index is still reported regardless of tracker_.
  if (queues_options_.at(queue_idx).track_task && tracker_ != nullptr) {
    fn = [
      task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
    ]() mutable {
      task();
    };
  }
  queues_[queue_idx]->AddTask(std::move(fn));
}
132

L
liutiexing 已提交
133 134 135 136
size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
  assert(queue_idx < queues_.size());
  return queues_.at(queue_idx)->NumThreads();
}
137

L
liutiexing 已提交
138 139 140 141
size_t WorkQueueGroupImpl::QueueGroupNumThreads() const {
  size_t total_num = 0;
  for (auto queue : queues_) {
    total_num += queue->NumThreads();
142
  }
L
liutiexing 已提交
143 144
  return total_num;
}
145

146 147 148 149 150 151 152 153 154
void WorkQueueGroupImpl::Cancel() {
  for (auto queue : queues_) {
    queue->Cancel();
  }
  for (auto queue : queues_) {
    queue->WaitThreadsExit();
  }
}

L
liutiexing 已提交
155
}  // namespace
156

L
liutiexing 已提交
157 158 159 160 161 162 163
// Creates a WorkQueue that runs on exactly one worker thread. Raises an
// InvalidArgument error via PADDLE_ENFORCE_EQ unless options.num_threads == 1.
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(
    const WorkQueueOptions& options) {
  PADDLE_ENFORCE_EQ(options.num_threads, 1u,
                    platform::errors::InvalidArgument(
                        "For a SingleThreadedWorkQueue, "
                        "WorkQueueOptions.num_threads must equals to 1."));
  return std::unique_ptr<WorkQueue>(new WorkQueueImpl(options));
}
166

L
liutiexing 已提交
167 168 169 170 171 172 173 174
// Creates a WorkQueue backed by more than one worker thread. Raises an
// InvalidArgument error via PADDLE_ENFORCE_GT unless options.num_threads > 1.
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(
    const WorkQueueOptions& options) {
  PADDLE_ENFORCE_GT(
      options.num_threads, 1u,
      platform::errors::InvalidArgument("For a MultiThreadedWorkQueue, "
                                        "WorkQueueOptions.num_threads must be "
                                        "greater than 1."));
  std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
  // Return the local directly: wrapping it in std::move only inhibits copy
  // elision (it already has the function's return type).
  return ptr;
}
177

L
liutiexing 已提交
178 179 180 181 182 183 184
// Creates a WorkQueueGroup with one queue per entry of `queues_options`.
// Raises an InvalidArgument error via PADDLE_ENFORCE_GT unless more than one
// set of options is given.
std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
    const std::vector<WorkQueueOptions>& queues_options) {
  PADDLE_ENFORCE_GT(queues_options.size(), 1u,
                    platform::errors::InvalidArgument(
                        "For a WorkQueueGroup, the number of WorkQueueOptions "
                        "must be greater than 1."));
  std::unique_ptr<WorkQueueGroup> ptr(new WorkQueueGroupImpl(queues_options));
  // Return the local directly: wrapping it in std::move only inhibits copy
  // elision (it already has the function's return type).
  return ptr;
}

188 189
}  // namespace framework
}  // namespace paddle