提交 847e4f4e 编写于 作者: Q Qiao Longfei

pure async mode train

上级 f768fbf7
...@@ -14,10 +14,31 @@ ...@@ -14,10 +14,31 @@
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h" #include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs) const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
...@@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
executors_.emplace_back(new details::ThreadedSSAGraphExecutor( executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i])); strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
} }
}
FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::vector<std::future<FeedFetchList>> run_futures;
std::vector<FeedFetchList> fetch_data; for (auto &node : graphs_[0]->Nodes()) {
FeedFetchList ret; if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos_.emplace_back();
fetch_data.reserve(places_.size()); var_infos_.back().name_ = node->Var()->Name();
ret.reserve(fetch_tensors.size()); var_infos_.back().type_ = node->Var()->GetType();
exception_holder_.Clear(); var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
}
}
for (size_t i = 0; i < places_.size(); ++i) { void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
auto call = [this, i, &fetch_tensors]() -> FeedFetchList { VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
for (size_t i = 1; i < places_.size(); ++i) {
auto call = [this, i]() -> void {
VLOG(3) << "start off python thread " << i;
try { try {
return executors_[i]->Run(fetch_tensors); while (true) {
executors_[i]->Run({});
}
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
VLOG(3) << "get exception type = " << exception_holder_.Type();
} }
return FeedFetchList(); VLOG(3) << "thread " << i << " exited!";
}; };
run_futures_.emplace_back(pool_->enqueue(std::move(call)));
if (pool_) {
run_futures.emplace_back(pool_->enqueue(std::move(call)));
} else {
fetch_data.emplace_back(std::move(call()));
}
} }
}
if (pool_) { void AsyncSSAGraphExecutor::HandleException() {
for (auto &f : run_futures) {
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
for (auto &f : run_futures_) {
VLOG(3) << "wait future";
f.wait(); f.wait();
} else {
fetch_data.emplace_back(std::move(f.get()));
}
}
} }
if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type() VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it"; << ", rethrow it";
run_futures_.clear();
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} }
}
FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
exception_holder_.Clear();
StartOffPythonTrainLoop();
}
if (places_.size() == 1) {
exception_holder_.Clear();
} else {
HandleException();
}
FeedFetchList fetch_data;
fetch_data.reserve(fetch_tensors.size());
try {
fetch_data = executors_[0]->Run(fetch_tensors);
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
HandleException();
FeedFetchList ret;
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) { for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs; std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.reserve(local_scopes_.size()); lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
}
ret.emplace_back(); ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace()); ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
} }
......
...@@ -24,6 +24,12 @@ namespace paddle { ...@@ -24,6 +24,12 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
// Plain record describing one variable of the execution graph. The
// AsyncSSAGraphExecutor constructor fills one entry per non-control variable
// node, and NewTempScopeAndInitVars() reads the entries to create the
// variables inside a scope.
struct VarInfo {
// Variable name; used as the key for Scope::FindVar() / Scope::Var().
std::string name_;
// Variable type passed to InitializeVariable() when the variable is created.
proto::VarType::Type type_;
// True: the variable is created in the scope handed to
// NewTempScopeAndInitVars(); false: it is created in the scope that function
// obtains from NewScope().
bool persistable_;
};
class AsyncSSAGraphExecutor : public SSAGraphExecutor { class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public: public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy, AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
...@@ -35,6 +41,10 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { ...@@ -35,6 +41,10 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
void StartOffPythonTrainLoop();
void HandleException();
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
...@@ -44,6 +54,8 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { ...@@ -44,6 +54,8 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_; std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_; ExceptionHolder exception_holder_;
std::vector<std::future<void>> run_futures_;
std::vector<VarInfo> var_infos_;
}; };
} // namespace details } // namespace details
......
...@@ -119,6 +119,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -119,6 +119,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
......
...@@ -379,9 +379,11 @@ ParallelExecutor::ParallelExecutor( ...@@ -379,9 +379,11 @@ ParallelExecutor::ParallelExecutor(
} }
VLOG(3) << "use ScopeBufferedSSAGraphExecutor"; VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
if (!build_strategy.async_mode_) {
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_))); member_->places_, std::move(member_->executor_)));
}
} }
void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::BCastParamsToDevices(
......
...@@ -69,6 +69,9 @@ void ReaderBase::Start() { ...@@ -69,6 +69,9 @@ void ReaderBase::Start() {
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
DecoratedReader::~DecoratedReader() { reader_->Shutdown(); } DecoratedReader::~DecoratedReader() {
VLOG(1) << "~DecoratedReader";
reader_->Shutdown();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -77,7 +77,10 @@ class DecoratedReader : public ReaderBase, ...@@ -77,7 +77,10 @@ class DecoratedReader : public ReaderBase,
~DecoratedReader(); ~DecoratedReader();
protected: protected:
void ShutdownImpl() override { reader_->Shutdown(); } void ShutdownImpl() override {
VLOG(1) << "ShutdownImpl";
reader_->Shutdown();
}
void StartImpl() override { reader_->Start(); } void StartImpl() override { reader_->Start(); }
...@@ -98,6 +101,8 @@ class ReaderHolder { ...@@ -98,6 +101,8 @@ class ReaderHolder {
reader_ = reader_base; reader_ = reader_base;
} }
~ReaderHolder() { VLOG(1) << "~ReaderHolder"; }
const std::shared_ptr<ReaderBase>& Get() const { return reader_; } const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
...@@ -106,6 +111,7 @@ class ReaderHolder { ...@@ -106,6 +111,7 @@ class ReaderHolder {
} }
void ResetAll() { void ResetAll() {
VLOG(1) << "ResetAll";
auto end_readers = reader_->GetEndPoints(); auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) { for (auto* reader : end_readers) {
reader->Shutdown(); reader->Shutdown();
...@@ -116,11 +122,13 @@ class ReaderHolder { ...@@ -116,11 +122,13 @@ class ReaderHolder {
} }
void Shutdown() { void Shutdown() {
VLOG(1) << "Shutdown";
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown(); reader_->Shutdown();
} }
void Start() { void Start() {
VLOG(1) << "start";
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Start(); reader_->Start();
} }
......
...@@ -86,6 +86,7 @@ class BlockingQueue { ...@@ -86,6 +86,7 @@ class BlockingQueue {
void ReOpen() { void ReOpen() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
VLOG(1) << "reopen queue";
closed_ = false; closed_ = false;
std::deque<T> new_deque; std::deque<T> new_deque;
queue_.swap(new_deque); queue_.swap(new_deque);
...@@ -95,7 +96,7 @@ class BlockingQueue { ...@@ -95,7 +96,7 @@ class BlockingQueue {
void Close() { void Close() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
VLOG(3) << "close queue"; VLOG(1) << "close queue";
closed_ = true; closed_ = true;
send_cv_.notify_all(); send_cv_.notify_all();
receive_cv_.notify_all(); receive_cv_.notify_all();
......
...@@ -20,6 +20,7 @@ namespace paddle { ...@@ -20,6 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
BufferedReader::~BufferedReader() { BufferedReader::~BufferedReader() {
VLOG(1) << "~BufferedReader";
reader_->Shutdown(); reader_->Shutdown();
while (!position_.empty()) { while (!position_.empty()) {
position_.front().wait(); position_.front().wait();
...@@ -41,6 +42,7 @@ BufferedReader::BufferedReader( ...@@ -41,6 +42,7 @@ BufferedReader::BufferedReader(
thread_pool_(1), thread_pool_(1),
place_(place), place_(place),
buffer_size_(buffer_size) { buffer_size_(buffer_size) {
VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device); platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
...@@ -121,6 +123,7 @@ void BufferedReader::ReadAsync(size_t i) { ...@@ -121,6 +123,7 @@ void BufferedReader::ReadAsync(size_t i) {
} }
void BufferedReader::ShutdownImpl() { void BufferedReader::ShutdownImpl() {
VLOG(1) << "ShutdownImpl";
reader_->Shutdown(); reader_->Shutdown();
while (!position_.empty()) { while (!position_.empty()) {
position_.pop(); position_.pop();
......
...@@ -33,10 +33,13 @@ class PyReader : public framework::FileReader { ...@@ -33,10 +33,13 @@ class PyReader : public framework::FileReader {
if (!success) out->clear(); if (!success) out->clear();
} }
~PyReader() { queue_->Close(); } ~PyReader() {
VLOG(1) << "~PyReader";
queue_->Close();
}
void Shutdown() override { void Shutdown() override {
VLOG(3) << "PyReader shutdown!"; VLOG(1) << "PyReader shutdown!";
queue_->Close(); queue_->Close();
} }
......
...@@ -57,7 +57,10 @@ class LoDTensorBlockingQueue { ...@@ -57,7 +57,10 @@ class LoDTensorBlockingQueue {
inline void ReOpen() { queue_.ReOpen(); } inline void ReOpen() { queue_.ReOpen(); }
inline void Close() { queue_.Close(); } inline void Close() {
VLOG(1) << "LoDTensorBlockingQueue close";
queue_.Close();
}
inline bool IsClosed() const { return queue_.IsClosed(); } inline bool IsClosed() const { return queue_.IsClosed(); }
......
...@@ -557,6 +557,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -557,6 +557,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_lod_tensor_blocking_queue", m.def("init_lod_tensor_blocking_queue",
[](Variable &var, [](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> { size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
VLOG(1) << "init_lod_tensor_blocking_queue";
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>(); auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode); holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue(); return holder->GetQueue();
......
...@@ -36,7 +36,7 @@ def convolutional_neural_network(use_py_reader): ...@@ -36,7 +36,7 @@ def convolutional_neural_network(use_py_reader):
capacity=64, capacity=64,
feed_list=[img, label], feed_list=[img, label],
name='py_reader', name='py_reader',
use_double_buffer=True) use_double_buffer=False)
img, label = fluid.layers.read_file(py_reader) img, label = fluid.layers.read_file(py_reader)
conv_pool_1 = fluid.nets.simple_img_conv_pool( conv_pool_1 = fluid.nets.simple_img_conv_pool(
...@@ -139,19 +139,20 @@ def train(use_cuda, thread_num, cpu_num): ...@@ -139,19 +139,20 @@ def train(use_cuda, thread_num, cpu_num):
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
py_reader.decorate_paddle_reader(train_reader) py_reader.decorate_paddle_reader(train_reader)
py_reader.start()
for pass_id in range(2):
step = 0 step = 0
py_reader.start()
try: try:
while True: while True:
loss_val = pe.run(fetch_list=[avg_loss.name]) loss_val = pe.run(fetch_list=[avg_loss.name])
loss_val = numpy.mean(loss_val) loss_val = numpy.mean(loss_val)
if step % 100 == 0: if step % 10 == 0:
print("Batch %d, Cost %f, queue size %d" % print("Pass %d, Batch %d, Cost %f, queue size %d" %
(step, loss_val, py_reader.queue.size())) (pass_id, step, loss_val, py_reader.queue.size()))
step += 1 step += 1
except fluid.core.EOFException: except fluid.core.EOFException:
print("train end") print("train end pass = " + str(pass_id))
py_reader.reset() py_reader.reset()
return step return step
...@@ -161,10 +162,11 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase): ...@@ -161,10 +162,11 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
def test_check_async_ssa_exe_train(self): def test_check_async_ssa_exe_train(self):
step_list = [] step_list = []
for cpu_num in [1, 2, 4]: for cpu_num in [1, 2, 4]:
scope = fluid.core.Scope() print("run cpu_num -> " + str(cpu_num))
with fluid.scope_guard(scope): with fluid.scope_guard(fluid.core.Scope()):
with fluid.program_guard( with fluid.program_guard(
fluid.Program(), startup_program=fluid.Program()): main_program=fluid.Program(),
startup_program=fluid.Program()):
start_time = time.time() start_time = time.time()
step = train( step = train(
use_cuda=False, thread_num=cpu_num, cpu_num=cpu_num) use_cuda=False, thread_num=cpu_num, cpu_num=cpu_num)
...@@ -173,7 +175,8 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase): ...@@ -173,7 +175,8 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
print("cpu_num -> " + str(cpu_num) + " step -> " + str(step) + print("cpu_num -> " + str(cpu_num) + " step -> " + str(step) +
" time -> " + str(end_time - start_time)) " time -> " + str(end_time - start_time))
with fluid.program_guard( with fluid.program_guard(
fluid.Program(), startup_program=fluid.Program()): main_program=fluid.Program(),
startup_program=fluid.Program()):
test() test()
assert int(step_list[0] / 2) == int(step_list[1]) assert int(step_list[0] / 2) == int(step_list[1])
assert int(step_list[1] / 2) == int(step_list[2]) assert int(step_list[1] / 2) == int(step_list[2])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册