Commit 847e4f4e, authored by Qiao Longfei

pure async mode train

Parent f768fbf7
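For orientation, here is a minimal, self-contained sketch of the pattern this commit introduces (stand-in types only, not the real Paddle classes): every executor except the first is pushed onto a background thread that loops over its graph, while executor 0 stays on the calling (Python) thread and is the only one that serves fetch requests. std::async stands in for the executor's thread pool, and the loop is bounded here instead of `while (true)`.

#include <future>
#include <iostream>
#include <vector>

struct FakeExecutor {  // stand-in for ThreadedSSAGraphExecutor
  int id;
  void Run() { std::cout << "executor " << id << " ran one batch\n"; }
};

int main() {
  std::vector<FakeExecutor> executors{{0}, {1}, {2}};
  std::vector<std::future<void>> run_futures;
  // StartOffPythonTrainLoop: executors 1..N-1 each loop on their own thread
  // (std::async here; the real code enqueues onto a thread pool).
  for (size_t i = 1; i < executors.size(); ++i) {
    run_futures.emplace_back(std::async(std::launch::async, [&executors, i] {
      for (int step = 0; step < 3; ++step) {  // `while (true)` in the real code
        executors[i].Run();
      }
    }));
  }
  // Only executor 0 runs on the calling thread and returns fetch results.
  executors[0].Run();
  for (auto &f : run_futures) f.wait();
}

In the real AsyncSSAGraphExecutor the background loops call `executors_[i]->Run({})` with no fetch targets and only stop when an exception (for example, reader EOF) escapes Run.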
@@ -14,10 +14,31 @@
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"

namespace paddle {
namespace framework {
namespace details {

inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
                                    Scope *scope) {
  Scope &local_scope = scope->NewScope();
  *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
      &local_scope;

  for (auto &info : var_infos) {
    if (scope->FindVar(info.name_) != nullptr) {
      continue;
    }

    if (info.persistable_) {  // Persistable
      InitializeVariable(scope->Var(info.name_), info.type_);
    } else {
      InitializeVariable(local_scope.Var(info.name_), info.type_);
    }
  }
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
@@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
    executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
        strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
  }
}
FeedFetchList AsyncSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  std::vector<std::future<FeedFetchList>> run_futures;
  std::vector<FeedFetchList> fetch_data;
  FeedFetchList ret;

  fetch_data.reserve(places_.size());
  ret.reserve(fetch_tensors.size());
  exception_holder_.Clear();

  for (auto &node : graphs_[0]->Nodes()) {
    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
      var_infos_.emplace_back();
      var_infos_.back().name_ = node->Var()->Name();
      var_infos_.back().type_ = node->Var()->GetType();
      var_infos_.back().persistable_ = node->Var()->Persistable();
    }
  }
  for (auto *scope : local_scopes_) {
    NewTempScopeAndInitVars(var_infos_, scope);
  }
}
  for (size_t i = 0; i < places_.size(); ++i) {
    auto call = [this, i, &fetch_tensors]() -> FeedFetchList {

void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
  VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
  for (size_t i = 1; i < places_.size(); ++i) {
    auto call = [this, i]() -> void {
      VLOG(3) << "start off python thread " << i;
      try {
        return executors_[i]->Run(fetch_tensors);
        while (true) {
          executors_[i]->Run({});
        }
      } catch (...) {
        exception_holder_.Catch(std::current_exception());
        VLOG(3) << "get exception type = " << exception_holder_.Type();
      }
      return FeedFetchList();
      VLOG(3) << "thread " << i << " exited!";
    };
    if (pool_) {
      run_futures.emplace_back(pool_->enqueue(std::move(call)));
    } else {
      fetch_data.emplace_back(std::move(call()));
    }
    run_futures_.emplace_back(pool_->enqueue(std::move(call)));
  }
}

  if (pool_) {
    for (auto &f : run_futures) {

void AsyncSSAGraphExecutor::HandleException() {
  if (exception_holder_.IsCaught()) {
    for (auto &f : run_futures_) {
      VLOG(3) << "wait future";
      f.wait();
    } else {
      fetch_data.emplace_back(std::move(f.get()));
    }
  }
}
  if (exception_holder_.IsCaught()) {
    VLOG(3) << "caught exception " << exception_holder_.Type()
            << ", rethrow it";
    run_futures_.clear();
    exception_holder_.ReThrow();
  }
}
FeedFetchList AsyncSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // init once
  if (run_futures_.size() == 0 && places_.size() > 1) {
    exception_holder_.Clear();
    StartOffPythonTrainLoop();
  }

  if (places_.size() == 1) {
    exception_holder_.Clear();
  } else {
    HandleException();
  }

  FeedFetchList fetch_data;
  fetch_data.reserve(fetch_tensors.size());

  try {
    fetch_data = executors_[0]->Run(fetch_tensors);
  } catch (...) {
    exception_holder_.Catch(std::current_exception());
  }
  HandleException();

  FeedFetchList ret;
  for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
    std::vector<const LoDTensor *> lodtensor_ptrs;
    lodtensor_ptrs.reserve(local_scopes_.size());
    for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
      lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
    }
    lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
    ret.emplace_back();
    ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
  }
......
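The control flow above leans on ExceptionHolder to carry exceptions across threads: a background loop captures whatever escaped Run, and the Python-facing thread rethrows it on its next call. A minimal sketch of that hand-off, using a bare std::exception_ptr in place of Paddle's ExceptionHolder:

#include <exception>
#include <iostream>
#include <stdexcept>
#include <thread>

int main() {
  std::exception_ptr eptr;  // plays the role of exception_holder_
  std::thread worker([&] {
    try {
      throw std::runtime_error("reader exhausted");  // e.g. EOF escaping Run()
    } catch (...) {
      eptr = std::current_exception();  // exception_holder_.Catch(...)
    }
  });
  worker.join();  // like waiting on run_futures_ in HandleException()
  try {
    if (eptr) std::rethrow_exception(eptr);  // exception_holder_.ReThrow()
  } catch (const std::exception &e) {
    std::cout << "rethrown on the Python-facing thread: " << e.what() << "\n";
  }
}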
@@ -24,6 +24,12 @@ namespace paddle {
namespace framework {
namespace details {

struct VarInfo {
  std::string name_;
  proto::VarType::Type type_;
  bool persistable_;
};

class AsyncSSAGraphExecutor : public SSAGraphExecutor {
 public:
  AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
@@ -35,6 +41,10 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

 private:
  void StartOffPythonTrainLoop();
  void HandleException();

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
@@ -44,6 +54,8 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
  std::vector<std::future<void>> run_futures_;
  std::vector<VarInfo> var_infos_;
};

}  // namespace details
......
@@ -119,6 +119,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
    if (timeout) {
      if (exception_holder_.IsCaught()) {
        VLOG(3) << "caught exception " << exception_holder_.Type()
                << ", rethrow it";
        for (auto &run_op_future : run_op_futures_) {
          run_op_future.wait();
        }
......
@@ -379,9 +379,11 @@ ParallelExecutor::ParallelExecutor(
  }

  VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
  if (!build_strategy.async_mode_) {
    member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
        exec_strategy, member_->local_scopes_, std::move(var_infos),
        member_->places_, std::move(member_->executor_)));
  }
}
void ParallelExecutor::BCastParamsToDevices(
......
@@ -69,6 +69,9 @@ void ReaderBase::Start() {
ReaderBase::~ReaderBase() {}

DecoratedReader::~DecoratedReader() { reader_->Shutdown(); }
DecoratedReader::~DecoratedReader() {
  VLOG(1) << "~DecoratedReader";
  reader_->Shutdown();
}

}  // namespace framework
}  // namespace paddle
@@ -77,7 +77,10 @@ class DecoratedReader : public ReaderBase,
  ~DecoratedReader();

 protected:
  void ShutdownImpl() override { reader_->Shutdown(); }
  void ShutdownImpl() override {
    VLOG(1) << "ShutdownImpl";
    reader_->Shutdown();
  }

  void StartImpl() override { reader_->Start(); }
@@ -98,6 +101,8 @@ class ReaderHolder {
    reader_ = reader_base;
  }

  ~ReaderHolder() { VLOG(1) << "~ReaderHolder"; }

  const std::shared_ptr<ReaderBase>& Get() const { return reader_; }

  void ReadNext(std::vector<LoDTensor>* out) {
@@ -106,6 +111,7 @@ class ReaderHolder {
  }

  void ResetAll() {
    VLOG(1) << "ResetAll";
    auto end_readers = reader_->GetEndPoints();
    for (auto* reader : end_readers) {
      reader->Shutdown();
@@ -116,11 +122,13 @@
  }

  void Shutdown() {
    VLOG(1) << "Shutdown";
    PADDLE_ENFORCE_NOT_NULL(reader_);
    reader_->Shutdown();
  }

  void Start() {
    VLOG(1) << "start";
    PADDLE_ENFORCE_NOT_NULL(reader_);
    reader_->Start();
  }
......
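The reader hunks above and below mostly add VLOGs along the shutdown path. The shutdown itself is a decorator chain: each DecoratedReader forwards Shutdown() to the reader it wraps, so shutting down the outermost reader (e.g. a BufferedReader) eventually reaches the root PyReader and closes its queue. A simplified stand-in sketch of that chain:

#include <iostream>
#include <memory>
#include <utility>

struct ReaderBase {
  virtual ~ReaderBase() = default;
  virtual void Shutdown() = 0;
};

struct RootReader : ReaderBase {  // stand-in for PyReader
  void Shutdown() override { std::cout << "root: close the blocking queue\n"; }
};

struct Decorated : ReaderBase {  // stand-in for DecoratedReader
  explicit Decorated(std::shared_ptr<ReaderBase> r) : reader_(std::move(r)) {}
  void Shutdown() override {
    std::cout << "decorated: forwarding Shutdown\n";
    reader_->Shutdown();  // same forwarding as DecoratedReader::ShutdownImpl
  }
  std::shared_ptr<ReaderBase> reader_;
};

int main() {
  // Two decorators over a root reader, roughly BufferedReader-over-PyReader.
  auto chain = std::make_shared<Decorated>(
      std::make_shared<Decorated>(std::make_shared<RootReader>()));
  chain->Shutdown();  // prints two "decorated" lines, then the root line
}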
@@ -86,6 +86,7 @@ class BlockingQueue {
  void ReOpen() {
    std::lock_guard<std::mutex> lock(mutex_);
    VLOG(1) << "reopen queue";
    closed_ = false;
    std::deque<T> new_deque;
    queue_.swap(new_deque);
@@ -95,7 +96,7 @@
  void Close() {
    std::lock_guard<std::mutex> lock(mutex_);
    VLOG(3) << "close queue";
    VLOG(1) << "close queue";
    closed_ = true;
    send_cv_.notify_all();
    receive_cv_.notify_all();
......
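ReOpen() and Close() above are what make the per-pass `py_reader.start()` / `py_reader.reset()` cycle in the test work: Close() wakes every blocked consumer so it can observe end-of-data, and ReOpen() clears the flag and the buffer for the next pass. A condensed sketch of those semantics (not the real BlockingQueue, which also enforces capacity and a speed-test mode):

#include <condition_variable>
#include <deque>
#include <mutex>

template <typename T>
class MiniBlockingQueue {
 public:
  void Send(T v) {
    std::lock_guard<std::mutex> lock(mutex_);
    queue_.push_back(std::move(v));
    cv_.notify_one();
  }
  bool Receive(T *out) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return closed_ || !queue_.empty(); });
    if (queue_.empty()) return false;  // closed and drained: signal EOF
    *out = std::move(queue_.front());
    queue_.pop_front();
    return true;
  }
  void Close() {
    std::lock_guard<std::mutex> lock(mutex_);
    closed_ = true;
    cv_.notify_all();  // wake every blocked Receive()
  }
  void ReOpen() {
    std::lock_guard<std::mutex> lock(mutex_);
    closed_ = false;
    queue_.clear();  // start the next pass with an empty buffer
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<T> queue_;
  bool closed_ = false;
};

A Receive() that returns false after Close() is roughly what surfaces as fluid.core.EOFException on the Python side in the test further down.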
@@ -20,6 +20,7 @@ namespace paddle {
namespace operators {
namespace reader {

BufferedReader::~BufferedReader() {
  VLOG(1) << "~BufferedReader";
  reader_->Shutdown();
  while (!position_.empty()) {
    position_.front().wait();
@@ -41,6 +42,7 @@ BufferedReader::BufferedReader(
      thread_pool_(1),
      place_(place),
      buffer_size_(buffer_size) {
  VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
@@ -121,6 +123,7 @@ void BufferedReader::ReadAsync(size_t i) {
}

void BufferedReader::ShutdownImpl() {
  VLOG(1) << "ShutdownImpl";
  reader_->Shutdown();
  while (!position_.empty()) {
    position_.pop();
......
@@ -33,10 +33,13 @@ class PyReader : public framework::FileReader {
    if (!success) out->clear();
  }

  ~PyReader() { queue_->Close(); }
  ~PyReader() {
    VLOG(1) << "~PyReader";
    queue_->Close();
  }

  void Shutdown() override {
    VLOG(3) << "PyReader shutdown!";
    VLOG(1) << "PyReader shutdown!";
    queue_->Close();
  }
......
@@ -57,7 +57,10 @@ class LoDTensorBlockingQueue {
  inline void ReOpen() { queue_.ReOpen(); }

  inline void Close() { queue_.Close(); }
  inline void Close() {
    VLOG(1) << "LoDTensorBlockingQueue close";
    queue_.Close();
  }

  inline bool IsClosed() const { return queue_.IsClosed(); }
......
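The `init_lod_tensor_blocking_queue` binding in the next hunk goes through a holder whose InitOnce creates the queue on first use and then keeps handing out the same shared_ptr. A rough stand-in sketch of that holder pattern (the null-check guard here is an assumption; the real LoDTensorBlockingQueueHolder may enforce single initialization differently):

#include <cstddef>
#include <memory>

struct Queue {  // stand-in for LoDTensorBlockingQueue
  explicit Queue(size_t c) : capacity(c) {}
  size_t capacity;
};

class QueueHolder {  // stand-in for LoDTensorBlockingQueueHolder
 public:
  void InitOnce(size_t capacity) {
    if (queue_ == nullptr) {  // assumed guard: only the first call creates
      queue_ = std::make_shared<Queue>(capacity);
    }
  }
  std::shared_ptr<Queue> GetQueue() const { return queue_; }

 private:
  std::shared_ptr<Queue> queue_;
};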
@@ -557,6 +557,7 @@ All parameter, weight, gradient are variables in Paddle.
  m.def("init_lod_tensor_blocking_queue",
        [](Variable &var,
           size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
          VLOG(1) << "init_lod_tensor_blocking_queue";
          auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
          holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
          return holder->GetQueue();
......
@@ -36,7 +36,7 @@ def convolutional_neural_network(use_py_reader):
        capacity=64,
        feed_list=[img, label],
        name='py_reader',
        use_double_buffer=True)
        use_double_buffer=False)
    img, label = fluid.layers.read_file(py_reader)

    conv_pool_1 = fluid.nets.simple_img_conv_pool(
@@ -139,19 +139,20 @@ def train(use_cuda, thread_num, cpu_num):
        exec_strategy=exec_strategy)

    py_reader.decorate_paddle_reader(train_reader)
    py_reader.start()

    for pass_id in range(2):
        step = 0
        py_reader.start()
        try:
            while True:
                loss_val = pe.run(fetch_list=[avg_loss.name])
                loss_val = numpy.mean(loss_val)
                if step % 100 == 0:
                    print("Batch %d, Cost %f, queue size %d" %
                          (step, loss_val, py_reader.queue.size()))
                if step % 10 == 0:
                    print("Pass %d, Batch %d, Cost %f, queue size %d" %
                          (pass_id, step, loss_val, py_reader.queue.size()))
                step += 1
        except fluid.core.EOFException:
            print("train end")
            print("train end pass = " + str(pass_id))
            py_reader.reset()

    return step
@@ -161,10 +162,11 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
    def test_check_async_ssa_exe_train(self):
        step_list = []
        for cpu_num in [1, 2, 4]:
            scope = fluid.core.Scope()
            with fluid.scope_guard(scope):
            print("run cpu_num -> " + str(cpu_num))
            with fluid.scope_guard(fluid.core.Scope()):
                with fluid.program_guard(
                        fluid.Program(), startup_program=fluid.Program()):
                        main_program=fluid.Program(),
                        startup_program=fluid.Program()):
                    start_time = time.time()
                    step = train(
                        use_cuda=False, thread_num=cpu_num, cpu_num=cpu_num)
@@ -173,7 +175,8 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
            print("cpu_num -> " + str(cpu_num) + " step -> " + str(step) +
                  " time -> " + str(end_time - start_time))

        with fluid.program_guard(
                fluid.Program(), startup_program=fluid.Program()):
                main_program=fluid.Program(),
                startup_program=fluid.Program()):
            test()

        assert int(step_list[0] / 2) == int(step_list[1])
        assert int(step_list[1] / 2) == int(step_list[2])
......