Unverified commit 89cfa491, authored by Zhen Wang, committed by GitHub

Unmerged fetch list (#22635)

* update ScopeBufferedSSAGraphExecutor, AsyncSSAGraphExecutor, ThreadedSSAGraphExecutor, FastThreadedSSAGraphExecutor, ParallelSSAGraphExecutor and ParallelExecutor to support fetching unmerged results.

* add a unit test for fetch_unmerged.

* update the unit test to cover multi-card and multi-CPU runs.

* add an error message and a user suggestion in FetchOpHandle. test=develop
Parent f05c213f
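Before diving into the diff, here is a minimal sketch of the user-facing change (it mirrors the `Examples 2` block added to the `Executor.run` docstring below; two GPU cards are assumed): with `return_merged=False`, each fetched variable comes back as one result per device instead of a single tensor merged along dimension 0.

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)

data = fluid.data(name='X', shape=[None, 1], dtype='float32')
prediction = fluid.layers.fc(input=data, size=2)
loss = fluid.layers.mean(prediction)
fluid.optimizer.Adam().minimize(loss)
exe.run(fluid.default_startup_program())

binary = fluid.CompiledProgram(fluid.default_main_program()) \
    .with_data_parallel(loss_name=loss.name)
x = np.random.random(size=(6, 1)).astype('float32')

# Default (merged): one ndarray per fetch target, devices concatenated
# along dimension 0 -> shape (6, 2).
merged, = exe.run(binary, feed={'X': x}, fetch_list=[prediction.name])

# New in this PR (unmerged): one entry per device -> shape (2, 3, 2)
# when two GPU cards are used.
unmerged, = exe.run(binary, feed={'X': x}, fetch_list=[prediction.name],
                    return_merged=False)
```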
@@ -132,14 +132,14 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
   ProcessGraph(graphs_, local_scopes_[0]);
 }
 
-void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
+void AsyncSSAGraphExecutor::StartOffPythonTrainLoop(bool return_merged) {
   VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
   for (size_t i = 1; i < places_.size(); ++i) {
-    auto call = [this, i]() -> void {
+    auto call = [this, i, return_merged]() -> void {
       VLOG(3) << "start off python thread " << i;
       try {
         while (true) {
-          executors_[i]->Run({});
+          executors_[i]->Run({}, return_merged);
         }
       } catch (...) {
         exception_holder_.Catch(std::current_exception());
@@ -164,8 +164,12 @@ void AsyncSSAGraphExecutor::HandleException() {
   }
 }
 
-FeedFetchList AsyncSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
+FetchResultType AsyncSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
+  PADDLE_ENFORCE_EQ(return_merged, true,
+                    platform::errors::InvalidArgument(
+                        "AsyncSSAGraphExecutor does not support unmerged "
+                        "results to be fetched!"));
   // init once
   if (run_futures_.size() == 0 && places_.size() > 1) {
     if (strategy_.thread_barrier_) {
@@ -175,18 +179,17 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
 #endif
     }
     exception_holder_.Clear();
-    StartOffPythonTrainLoop();
+    StartOffPythonTrainLoop(return_merged);
   }
 
   if (places_.size() == 1) {
     exception_holder_.Clear();
   }
 
-  FeedFetchList fetch_data;
-  fetch_data.reserve(fetch_tensors.size());
+  FetchResultType fetch_data;
 
   try {
-    fetch_data = executors_[0]->Run(fetch_tensors);
+    fetch_data = executors_[0]->Run(fetch_tensors, return_merged);
   } catch (...) {
     exception_holder_.Catch(std::current_exception());
   }
@@ -194,9 +197,10 @@ FeedFetchList AsyncSSAGraphExecutor::Run(
   HandleException();
 
   FeedFetchList ret;
+  auto &val = boost::get<FeedFetchList>(fetch_data);
   for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
     std::vector<const LoDTensor *> lodtensor_ptrs;
-    lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
+    lodtensor_ptrs.push_back(&val.at(fetch_idx));
     ret.emplace_back();
     ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
   }
......
@@ -42,10 +42,11 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   ~AsyncSSAGraphExecutor() final = default;
   const ir::Graph &Graph() const override { return *graphs_[0]; }
 
-  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
+  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
+                      bool return_merged) override;
 
  private:
-  void StartOffPythonTrainLoop();
+  void StartOffPythonTrainLoop(bool return_merged);
   void HandleException();
 
  private:
......
@@ -51,8 +51,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
   PrepareAtomicOpDeps();
 }
 
-FeedFetchList FastThreadedSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
+FetchResultType FastThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
   VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
   std::unique_ptr<platform::RecordEvent> event(
       new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
@@ -61,15 +61,19 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   PrepareAtomicOpDeps();
   size_t num_ops = op_deps->size();
 
-  paddle::framework::FeedFetchList fetches;
-  fetches.resize(fetch_tensors.size());
+  FetchResultType fetches;
+  if (return_merged) {
+    fetches = FeedFetchList(fetch_tensors.size());
+  } else {
+    fetches = FetchUnmergedList(fetch_tensors.size());
+  }
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
   std::vector<OpHandleBase *> fetch_ops;
   std::vector<OpHandleBase *> ready_fetch_ops;
   exception_.Clear();
 
   InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
-                 &fetch_ops, &ready_fetch_ops);
+                 &fetch_ops, &ready_fetch_ops, return_merged);
   event.reset(nullptr);
   if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
     // If the num_threads is 1, we can record the order of operator's
@@ -120,11 +124,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
 }
 
 void FastThreadedSSAGraphExecutor::InsertFetchOps(
-    const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
+    const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
     std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
     std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
     std::vector<OpHandleBase *> *fetch_ops,
-    std::vector<OpHandleBase *> *ready_fetch_ops) {
+    std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged) {
   std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
                                                    fetch_tensors.end());
   for (auto &fetch_var_name : fetch_tensor_set) {
@@ -154,7 +158,7 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
     ir::Node *fetch_node =
         graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
     auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
-                                 &local_exec_scopes_);
+                                 &local_exec_scopes_, return_merged);
     fetch_ops->emplace_back(op);
 
     for (auto &p : places_) {
......
@@ -36,7 +36,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
                                const std::vector<Scope *> &local_exec_scopes,
                                const std::vector<platform::Place> &places,
                                ir::Graph *graph);
-  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
+  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
+                      bool return_merged) override;
   const ir::Graph &Graph() const override;
 
  private:
@@ -83,12 +84,12 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
   bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
 
   void InsertFetchOps(
-      const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
+      const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
       std::unordered_map<std::string, std::vector<VarHandleBase *>>
           *fetched_vars,
       std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
       std::vector<OpHandleBase *> *fetch_ops,
-      std::vector<OpHandleBase *> *ready_fetch_ops);
+      std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged);
 };
 
 }  // namespace details
 }  // namespace framework
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 #include <string>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/platform/profiler.h"
@@ -21,14 +22,16 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
-                             std::vector<Scope *> *local_scopes,
-                             std::vector<Scope *> *local_exec_scopes)
+FetchOpHandle::FetchOpHandle(ir::Node *node, FetchResultType *data,
+                             size_t offset, std::vector<Scope *> *local_scopes,
+                             std::vector<Scope *> *local_exec_scopes,
+                             bool return_merged)
     : OpHandleBase(node),
       data_(data),
       offset_(offset),
       local_scopes_(local_scopes),
-      local_exec_scopes_(local_exec_scopes) {}
+      local_exec_scopes_(local_exec_scopes),
+      return_merged_(return_merged) {}
 
 FetchOpHandle::~FetchOpHandle() {}
 
@@ -37,12 +40,42 @@ void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
 }
 
 void FetchOpHandle::WaitAndMergeCPUTensors() const {
+  if (return_merged_) {
+    const auto &tensor_dims = tensors_[0].dims();
+    for (size_t i = 1; i < tensors_.size(); i++) {
+      const auto &ele_dims = tensors_[i].dims();
+      PADDLE_ENFORCE_EQ(
+          tensor_dims.size(), ele_dims.size(),
+          platform::errors::Fatal("The dimension sizes of fetched Tensors are "
+                                  "different from each other on different "
+                                  "devices. And the error is caused by the %zu "
+                                  "(th) fetched variable. Please set the "
+                                  "parameter `return_merged = False` when you "
+                                  "call the `Executor.run()` method.",
+                                  offset_));
+      for (int j = 1; j < tensor_dims.size(); j++) {
+        PADDLE_ENFORCE_EQ(
+            tensor_dims[j], ele_dims[j],
+            platform::errors::Fatal("The dimensions of fetched Tensors are "
+                                    "different from each other on different "
+                                    "devices. And the error is caused by the "
+                                    "%zu (th) fetched variable. Please set the "
+                                    "parameter `return_merged = False` when "
+                                    "you call the `Executor.run()` method.",
+                                    offset_));
+      }
+    }
     std::vector<const LoDTensor *> tensors_ptr;
     tensors_ptr.reserve(tensors_.size());
     for (auto &t : tensors_) {
       tensors_ptr.emplace_back(&t);
     }
-  data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
+    auto &val = boost::get<FeedFetchList>(*data_);
+    val.at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
+  } else {
+    auto &val = boost::get<FetchUnmergedList>(*data_);
+    val.at(offset_) = std::move(tensors_);
+  }
 }
 
 void FetchOpHandle::RunImpl() {
......
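In words, the merged path above first checks that every device produced tensors of the same rank and the same non-batch dimensions, then concatenates them along dimension 0 via `MergeLoDTensor`; the unmerged path simply moves the per-device list out. A rough numpy analogue of that logic (illustrative only; `per_device` stands in for `tensors_`):

```python
import numpy as np

def wait_and_merge(per_device, return_merged):
    # Numpy sketch of FetchOpHandle::WaitAndMergeCPUTensors.
    if return_merged:
        first = per_device[0].shape
        for t in per_device[1:]:
            # Mirrors the PADDLE_ENFORCE_EQ checks: ranks and all
            # non-batch dimensions must match across devices.
            assert t.ndim == len(first) and t.shape[1:] == first[1:], \
                "Set `return_merged = False` in Executor.run()."
        return np.concatenate(per_device, axis=0)  # MergeLoDTensor on dim 0
    return list(per_device)  # unmerged: keep one tensor per device

a, b = np.zeros((3, 2)), np.ones((3, 2))
print(wait_and_merge([a, b], True).shape)  # (6, 2)
print(len(wait_and_merge([a, b], False)))  # 2
```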
@@ -28,9 +28,9 @@ namespace details {
 struct FetchOpHandle : public OpHandleBase {
  public:
-  FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
+  FetchOpHandle(ir::Node *node, FetchResultType *data, size_t offset,
                 std::vector<Scope *> *local_scopes,
-                std::vector<Scope *> *local_exec_scopes);
+                std::vector<Scope *> *local_exec_scopes, bool return_merged);
 
   ~FetchOpHandle();
@@ -50,11 +50,12 @@ struct FetchOpHandle : public OpHandleBase {
   void WaitInputVarGenerated(const platform::Place &place) override;
 
  private:
-  FeedFetchList *data_;
+  FetchResultType *data_;
   size_t offset_;
   std::vector<Scope *> *local_scopes_;
   std::vector<Scope *> *local_exec_scopes_;
   std::vector<LoDTensor> tensors_;
+  bool return_merged_;
 };
 
 }  // namespace details
......
@@ -123,25 +123,33 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
   return result;
 }
 
-FeedFetchList ParallelSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
-  std::vector<std::future<FeedFetchList>> run_futures;
-  std::vector<FeedFetchList> fetch_data;
-  FeedFetchList ret;
+FetchResultType ParallelSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
+  std::vector<std::future<FetchResultType>> run_futures;
+  std::vector<FetchResultType> fetch_data;
+  FetchResultType ret;
 
   fetch_data.reserve(places_.size());
-  ret.reserve(fetch_tensors.size());
+  if (return_merged) {
+    ret = FeedFetchList();
+  } else {
+    ret = FetchUnmergedList();
+  }
   exception_holder_.Clear();
 
   for (size_t i = 0; i < places_.size(); ++i) {
-    auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
+    auto call = [this, i, return_merged, &fetch_tensors]() -> FetchResultType {
       try {
-        return executors_[i]->Run(fetch_tensors);
+        return executors_[i]->Run(fetch_tensors, return_merged);
       } catch (...) {
         exception_holder_.Catch(std::current_exception());
       }
+      if (return_merged) {
         return FeedFetchList();
+      } else {
+        return FetchUnmergedList();
+      }
     };
 
     if (pool_) {
@@ -164,14 +172,33 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
     exception_holder_.ReThrow();
   }
 
+  if (return_merged) {
+    auto &ret_val = boost::get<FeedFetchList>(ret);
     for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
       std::vector<const LoDTensor *> lodtensor_ptrs;
       lodtensor_ptrs.reserve(local_scopes_.size());
-    for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
-      lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
+      for (size_t scope_idx = 0; scope_idx < local_scopes_.size();
+           ++scope_idx) {
+        auto &val = boost::get<FeedFetchList>(fetch_data.at(scope_idx));
+        lodtensor_ptrs.push_back(&val.at(fetch_idx));
+      }
+      ret_val.emplace_back();
+      ret_val.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
+    }
+  } else {
+    auto &ret_val = boost::get<FetchUnmergedList>(ret);
+    for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
+      ret_val.emplace_back();
+      for (size_t scope_idx = 0; scope_idx < local_scopes_.size();
+           ++scope_idx) {
+        auto &val = boost::get<FetchUnmergedList>(fetch_data.at(scope_idx));
+        PADDLE_ENFORCE_EQ(
+            val.at(fetch_idx).size(), 1,
+            platform::errors::Fatal(
+                "Each place must have only one fetched LoDTensor!"));
+        ret_val.back().emplace_back(val.at(fetch_idx)[0]);
+      }
     }
-    ret.emplace_back();
-    ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
   }
   return ret;
 }
......
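Since each sub-executor of `ParallelSSAGraphExecutor` runs on a single place, its unmerged result holds exactly one tensor per fetch target (hence the `PADDLE_ENFORCE_EQ` on the size above), and the gather in the `else` branch is essentially a transpose from `[device][fetch]` to `[fetch][device]` order. A small sketch of that reshuffling (hypothetical placeholder data):

```python
def gather_unmerged(fetch_data, num_fetches):
    # fetch_data[dev][fetch] is a one-element list, because each place
    # contributes exactly one LoDTensor per fetch target.
    ret = []
    for fetch_idx in range(num_fetches):
        ret.append([dev_result[fetch_idx][0] for dev_result in fetch_data])
    return ret

fetch_data = [[['loss@dev0'], ['pred@dev0']],
              [['loss@dev1'], ['pred@dev1']]]
print(gather_unmerged(fetch_data, 2))
# [['loss@dev0', 'loss@dev1'], ['pred@dev0', 'pred@dev1']]
```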
@@ -39,7 +39,8 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<ir::Graph *> Graphs();
 
-  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
+  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
+                      bool return_merged) override;
 
  private:
   std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
......
@@ -41,19 +41,19 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
   PrepareLocalExeScopes();
 }
 
-FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
+FetchResultType ScopeBufferedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
   if (drop_scope_counter_ == 0) {
     platform::RecordEvent e("InitLocalVars");
     InitVariables();
   }
 
-  std::vector<framework::LoDTensor> fetch_data;
+  FetchResultType fetch_data;
   std::exception_ptr eptr = nullptr;
 
   auto exe_run_func = [&]() {
     try {
-      fetch_data = underlying_executor_->Run(fetch_tensors);
+      fetch_data = underlying_executor_->Run(fetch_tensors, return_merged);
     } catch (...) {
       eptr = std::current_exception();
     }
......
@@ -50,7 +50,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
     return underlying_executor_->Graph();
   }
 
-  FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
+  FetchResultType Run(const std::vector<std::string>& fetch_tensors,
+                      bool return_merged) override;
 
   void DropLocalExeScopes();
......
@@ -35,7 +35,8 @@ class SSAGraphExecutor {
   virtual const ir::Graph& Graph() const = 0;
 
-  virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
+  virtual FetchResultType Run(const std::vector<std::string>& fetch_tensors,
+                              bool return_merged = true) = 0;
 };
 
 void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops);
......
@@ -52,8 +52,8 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
   CopyOpDeps();
 }
 
-inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
-    const std::vector<std::string> &fetch_tensors) {
+inline FetchResultType ThreadedSSAGraphExecutor::RunImpl(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
   std::unique_ptr<platform::RecordEvent> event(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
   std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
@@ -70,10 +70,15 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
   // Step 2. Insert FetchOps
   std::vector<OpHandleBase *> fetch_ops;
   std::unordered_set<VarHandleBase *> fetch_dependencies;
-  FeedFetchList fetch_data(fetch_tensors.size());
+  FetchResultType fetch_data;
+  if (return_merged) {
+    fetch_data = FeedFetchList(fetch_tensors.size());
+  } else {
+    fetch_data = FetchUnmergedList(fetch_tensors.size());
+  }
 
   InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
-                 &pending_ops, &pending_vars, &fetch_data);
+                 &pending_ops, &pending_vars, &fetch_data, return_merged);
 
   exception_holder_.Clear();
   event.reset(nullptr);
@@ -142,12 +147,12 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
   return fetch_data;
 }
 
-FeedFetchList ThreadedSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
+FetchResultType ThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
   for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
-    RunImpl({});
+    RunImpl({}, return_merged);
   }
-  return RunImpl(fetch_tensors);
+  return RunImpl(fetch_tensors, return_merged);
 }
 
 void ThreadedSSAGraphExecutor::InsertFetchOps(
@@ -157,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     std::unordered_set<OpHandleBase *> *ready_ops,
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,
     std::unordered_set<VarHandleBase *> *pending_vars,
-    FeedFetchList *fetch_data) {
+    FetchResultType *fetch_data, bool return_merged) {
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
   std::unordered_set<VarHandleBase *> local_ready_vars;
   std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
@@ -189,7 +194,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     ir::Node *fetch_node =
         graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
     auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_,
-                                 &local_exec_scopes_);
+                                 &local_exec_scopes_, return_merged);
     fetch_ops->emplace_back(op);
 
     for (auto &p : places_) {
......
@@ -58,12 +58,14 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   const ir::Graph &Graph() const override { return *graph_; }
   // Run a SSAGraph by a thread pool
   // Use topological sort algorithm
-  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
+  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
+                      bool return_merged) override;
 
   ~ThreadedSSAGraphExecutor() final = default;
 
  private:
-  inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
+  inline FetchResultType RunImpl(const std::vector<std::string> &fetch_tensors,
+                                 bool return_merged);
   void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
              details::OpHandleBase *op);
@@ -99,7 +101,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                       std::unordered_set<OpHandleBase *> *ready_ops,
                       std::unordered_map<OpHandleBase *, size_t> *pending_ops,
                       std::unordered_set<VarHandleBase *> *pending_vars,
-                      FeedFetchList *fetch_data);
+                      FetchResultType *fetch_data, bool return_merged);
 
   void PrepareOpDeps();
......
@@ -15,11 +15,14 @@ limitations under the License. */
 #pragma once
 #include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/platform/variant.h"
 
 namespace paddle {
 namespace framework {
 using FeedFetchType = LoDTensor;
 using FeedFetchList = std::vector<FeedFetchType>;
+using FetchUnmergedList = std::vector<std::vector<FeedFetchType>>;
+using FetchResultType = boost::variant<FeedFetchList, FetchUnmergedList>;
 
 static const char kFeedOpType[] = "feed";
 static const char kFetchOpType[] = "fetch";
......
@@ -20,6 +20,7 @@ namespace paddle {
 namespace framework {
 using LoDTensorArray = std::vector<LoDTensor>;
+using LoDTensor2DArray = std::vector<std::vector<LoDTensor>>;
 
 }  // namespace framework
 }  // namespace paddle
@@ -723,8 +723,8 @@ void ParallelExecutor::BCastParamsToDevices(
   }
 }
 
-FeedFetchList ParallelExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
+FetchResultType ParallelExecutor::Run(
+    const std::vector<std::string> &fetch_tensors, bool return_merged) {
   VLOG(3) << "enter ParallelExecutor Run";
 #ifdef WITH_GPERFTOOLS
   if (gProfileStarted) {
@@ -738,7 +738,7 @@ FeedFetchList ParallelExecutor::Run(
                                     member_->HasGarbageCollectors());
 
   VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
-  auto fetch_data = member_->executor_->Run(fetch_tensors);
+  auto fetch_data = member_->executor_->Run(fetch_tensors, return_merged);
   return fetch_data;
 }
......
@@ -77,7 +77,8 @@ class ParallelExecutor {
   void FeedAndSplitTensorIntoLocalScopes(
       const std::unordered_map<std::string, LoDTensor> &tensors);
 
-  FeedFetchList Run(const std::vector<std::string> &fetch_tensors);
+  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
+                      bool return_merged = true);
 
  private:
   // broadcast the parameters from the 0th device.
......
@@ -32,6 +32,7 @@ DECLARE_bool(use_ngraph);
 DECLARE_bool(use_system_allocator);
 DECLARE_bool(free_idle_chunk);
 DECLARE_bool(free_when_no_cache_hit);
+DECLARE_bool(enable_parallel_graph);
 
 namespace paddle {
 namespace pybind {
@@ -169,6 +170,7 @@ static void RegisterGlobalVarGetterSetter() {
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph);
   REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb);
   REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_use_system_allocator);
+  REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_enable_parallel_graph);
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_idle_chunk);
   REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_when_no_cache_hit);
 }
......
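Registering the getter/setter here is what lets the new unit test toggle `ParallelSSAGraphExecutor` from Python. A short usage sketch (mirroring the unit test at the end of this PR; reading the flag back is an assumption based on the SETTER/GETTER registration above):

```python
import paddle.fluid as fluid

# Route execution through ParallelSSAGraphExecutor for the next run,
# then restore the default. The getter/setter registered above backs
# this dictionary-style access.
fluid.core.globals()['FLAGS_enable_parallel_graph'] = True
print(fluid.core.globals()['FLAGS_enable_parallel_graph'])  # True
fluid.core.globals()['FLAGS_enable_parallel_graph'] = False
```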
@@ -24,6 +24,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
+#include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/io/fs.h"
@@ -103,6 +104,7 @@ DECLARE_bool(use_ngraph);
 // disable auto conversion to list in Python
 PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
+PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensor2DArray);
 
 namespace paddle {
 namespace pybind {
@@ -1614,6 +1616,25 @@ All parameter, weight, gradient are variables in Paddle.
       },
       py::return_value_policy::take_ownership);
 
+  py::class_<LoDTensor2DArray>(m, "LoDTensor2DArray", R"DOC(
+    LoDTensor2DArray is 2-D array of LoDTensor.
+    )DOC")
+      .def("_move_to_list",
+           [](LoDTensor2DArray &self) -> py::list {
+             py::list res(self.size());
+             for (size_t i = 0; i < self.size(); ++i) {
+               py::list tmp(self[i].size());
+               for (size_t j = 0; j < self[i].size(); ++j) {
+                 tmp[j] = py::cast(std::move(self[i][j]));
+               }
+               res[i] = std::move(tmp);
+               self[i].clear();
+             }
+             self.clear();
+             return res;
+           },
+           py::return_value_policy::take_ownership);
+
   m.def("op_support_gpu", OpSupportGPU);
 #ifdef PADDLE_WITH_CUDA
   m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
@@ -2306,9 +2327,20 @@ All parameter, weight, gradient are variables in Paddle.
            &ParallelExecutor::FeedAndSplitTensorIntoLocalScopes)
       .def("run",
            [](ParallelExecutor &self,
-              const std::vector<std::string> &fetch_tensors) {
+              const std::vector<std::string> &fetch_tensors,
+              bool return_merged) -> py::object {
+             paddle::framework::FetchResultType ret;
+             {
               pybind11::gil_scoped_release release;
-             return self.Run(fetch_tensors);
+               ret = self.Run(fetch_tensors, return_merged);
+             }
+             if (return_merged) {
+               return py::cast(std::move(
+                   boost::get<paddle::framework::FeedFetchList>(ret)));
+             } else {
+               return py::cast(std::move(
+                   boost::get<paddle::framework::FetchUnmergedList>(ret)));
+             }
           })
       .def("device_count", &ParallelExecutor::DeviceCount);
......
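The `run` binding above is what `Executor._run_parallel` calls, and `_move_to_list` turns the opaque C++ containers into nested Python lists: a flat `[fetch_idx]` list when merged, a 2-D `[fetch_idx][device_idx]` list when unmerged. A sketch of the resulting nesting with placeholder arrays (shapes follow the unit test at the end of this PR: 2 devices, batch_size 64, `fetch_list=[loss, prediction]`):

```python
import numpy as np

# return_merged=True: one ndarray per fetch target, merged on dim 0.
merged = [np.zeros((2,)), np.zeros((64, 10))]  # [loss, prediction]

# return_merged=False: outer index is the fetch target, inner the device.
unmerged = [[np.zeros((1,)), np.zeros((1,))],          # loss per device
            [np.zeros((32, 10)), np.zeros((32, 10))]]  # prediction per device

assert np.array(unmerged[0]).shape == (2, 1)       # as asserted in the test
assert np.array(unmerged[1]).shape == (2, 32, 10)
```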
@@ -620,7 +620,7 @@ class Executor(object):
         self._closed = True
 
     def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
-                      return_numpy):
+                      return_numpy, return_merged):
         exe = program._executor
         # TODO(zhenghuihuang): quantization uses Graph in CompiledProgram
         # instead of program. We will add support for checking Vars in Graph
@@ -674,7 +674,7 @@ class Executor(object):
             exe.feed_tensors_into_local_scopes(res)
 
         fetch_var_names = list(map(_to_name_str, fetch_list))
-        tensors = exe.run(fetch_var_names)._move_to_list()
+        tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
         return as_numpy(tensors) if return_numpy else tensors
 
     def run(self,
@@ -685,7 +685,8 @@ class Executor(object):
             fetch_var_name='fetch',
             scope=None,
             return_numpy=True,
-            use_program_cache=False):
+            use_program_cache=False,
+            return_merged=True):
         """
         Run the specified :code:`Program` or :code:`CompiledProgram`. It should be noted that the executor
         will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some
@@ -724,6 +725,17 @@ class Executor(object):
                 the input program is :code:`fluid.Program`, and the parameters(program, feed variable name
                 and fetch_list variable) of this interface remains unchanged during running.
                 The default is False.
+            return_merged(bool): This parameter indicates whether fetched variables (the variables
+                specified in the fetch list) should be merged according to the execution device dimension.
+                If :code:`return_merged` is False, the return value is a two-dimensional list of
+                :code:`Tensor` ( :code:`return_numpy` is False) or a two-dimensional list of
+                :code:`numpy.ndarray` ( :code:`return_numpy` is True). If :code:`return_merged` is True,
+                the return value is a one-dimensional list of :code:`Tensor` ( :code:`return_numpy`
+                is False) or a one-dimensional list of :code:`numpy.ndarray` ( :code:`return_numpy` is True).
+                Please see Examples 2 for more details. If the lengths of the fetched results vary
+                across devices, set :code:`return_merged` to False so that the fetched results
+                are not merged. The default is True; this is kept only for compatibility, and the
+                default may change to False in a future version.
 
         Returns:
...@@ -743,7 +755,7 @@ class Executor(object): ...@@ -743,7 +755,7 @@ class Executor(object):
results are spliced together in dimension 0 for the same variable values results are spliced together in dimension 0 for the same variable values
(variables in fetch_list) on different devices. (variables in fetch_list) on different devices.
Examples: Examples 1:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -765,6 +777,66 @@ class Executor(object): ...@@ -765,6 +777,66 @@ class Executor(object):
x = numpy.random.random(size=(10, 1)).astype('float32') x = numpy.random.random(size=(10, 1)).astype('float32')
outs = exe.run(feed={'X': x}, outs = exe.run(feed={'X': x},
fetch_list=[loss.name]) fetch_list=[loss.name])
Examples 2:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
# First create the Executor.
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
data = fluid.data(name='X', shape=[None, 1], dtype='float32')
class_dim = 2
prediction = fluid.layers.fc(input=data, size=class_dim)
loss = fluid.layers.mean(prediction)
adam = fluid.optimizer.Adam()
adam.minimize(loss)
# Run the startup program once and only once.
exe.run(fluid.default_startup_program())
build_strategy = fluid.BuildStrategy()
binary = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
batch_size = 6
x = np.random.random(size=(batch_size, 1)).astype('float32')
# Set return_merged as False to fetch unmerged results:
unmerged_prediction, = exe.run(binary, feed={'X': x},
fetch_list=[prediction.name],
return_merged=False)
# If the user uses two GPU cards to run this python code, the printed result will be
# (2, 3, class_dim). The first dimension value of the printed result is the number of used
# GPU cards, and the second dimension value is the quotient of batch_size and the
# number of used GPU cards.
print("The unmerged prediction shape: {}".format(np.array(unmerged_prediction).shape))
print(unmerged_prediction)
# Set return_merged as True to fetch merged results:
merged_prediction, = exe.run(binary, feed={'X': x},
fetch_list=[prediction.name],
return_merged=True)
# If the user uses two GPU cards to run this python code, the printed result will be
# (6, class_dim). The first dimension value of the printed result is the batch_size.
print("The merged prediction shape: {}".format(np.array(merged_prediction).shape))
print(merged_prediction)
# Out:
# The unmerged prediction shape: (2, 3, 2)
# [array([[-0.37620035, -0.19752218],
# [-0.3561043 , -0.18697084],
# [-0.24129935, -0.12669306]], dtype=float32), array([[-0.24489994, -0.12858354],
# [-0.49041364, -0.25748932],
# [-0.44331917, -0.23276259]], dtype=float32)]
# The merged prediction shape: (6, 2)
# [[-0.37789783 -0.19921964]
# [-0.3577645 -0.18863106]
# [-0.24274671 -0.12814042]
# [-0.24635398 -0.13003758]
# [-0.49232286 -0.25939852]
# [-0.44514108 -0.2345845 ]]
""" """
         try:
             return self._run_impl(
@@ -775,7 +847,8 @@ class Executor(object):
                 fetch_var_name=fetch_var_name,
                 scope=scope,
                 return_numpy=return_numpy,
-                use_program_cache=use_program_cache)
+                use_program_cache=use_program_cache,
+                return_merged=return_merged)
         except Exception as e:
             if not isinstance(e, core.EOFException):
                 warnings.warn(
@@ -783,7 +856,8 @@ class Executor(object):
             six.reraise(*sys.exc_info())
 
     def _run_impl(self, program, feed, fetch_list, feed_var_name,
-                  fetch_var_name, scope, return_numpy, use_program_cache):
+                  fetch_var_name, scope, return_numpy, use_program_cache,
+                  return_merged):
         if self._closed:
             raise RuntimeError("Attempted to use a closed Executor")
@@ -840,7 +914,8 @@ class Executor(object):
                 feed=feed,
                 fetch_list=fetch_list,
                 fetch_var_name=fetch_var_name,
-                return_numpy=return_numpy)
+                return_numpy=return_numpy,
+                return_merged=return_merged)
 
     def _run_program(self, program, feed, fetch_list, feed_var_name,
                      fetch_var_name, scope, return_numpy, use_program_cache):
......
@@ -362,5 +362,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
                      test_parallel_executor_feed_persistable_var
                      test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
                      test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
-                     test_optimizer_in_control_flow
+                     test_optimizer_in_control_flow test_fetch_unmerged
                      test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
import paddle

os.environ["CPU_NUM"] = "2"


class TestFetchUnmerged(unittest.TestCase):
    def conv_net(self, img, label):
        conv_pool_1 = fluid.nets.simple_img_conv_pool(
            input=img,
            filter_size=5,
            num_filters=20,
            pool_size=2,
            pool_stride=2,
            pool_type='max',
            act="relu")
        conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
        conv_pool_2 = fluid.nets.simple_img_conv_pool(
            input=conv_pool_1,
            filter_size=5,
            num_filters=50,
            pool_size=2,
            pool_stride=2,
            pool_type='avg',
            act="relu")
        hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        avg_loss = fluid.layers.mean(loss)
        return avg_loss, prediction

    def build_program(self, main, startup, is_test):
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                img = fluid.layers.data(
                    name='image', shape=[1, 28, 28], dtype='float32')
                label = fluid.layers.data(
                    name='label', shape=[1], dtype='int64')
                loss, prediction = self.conv_net(img, label)
                if not is_test:
                    opt = fluid.optimizer.Adam(learning_rate=0.001)
                    opt.minimize(loss)
        return [img, label], loss, prediction

    def fetch_unmerged(self, use_cuda=True):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        feeds, loss, prediction = self.build_program(main_program,
                                                     startup_program, False)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        build_strategy = fluid.BuildStrategy()
        binary = fluid.CompiledProgram(main_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)

        iters = 3
        batch_size = 64
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)

        device_num = fluid.core.get_cuda_device_count() if use_cuda else 2
        for _ in range(iters):
            data = next(train_reader())
            loss_v, prediction_v = exe.run(binary,
                                           feed=feeder.feed(data),
                                           fetch_list=[loss, prediction],
                                           return_merged=False)
            self.assertEqual(np.array(loss_v).shape, (device_num, 1))
            self.assertEqual(
                np.array(prediction_v).shape,
                (device_num, batch_size / device_num, 10))

        for _ in range(iters):
            data = next(train_reader())
            loss_v, prediction_v = exe.run(binary,
                                           feed=feeder.feed(data),
                                           fetch_list=[loss, prediction],
                                           return_merged=True)
            self.assertEqual(np.array(loss_v).shape, (device_num, ))
            self.assertEqual(np.array(prediction_v).shape, (batch_size, 10))

    def test_fetch_unmerged(self):
        if fluid.core.is_compiled_with_cuda():
            self.fetch_unmerged(use_cuda=True)
        self.fetch_unmerged(use_cuda=False)

    def test_fetch_unmerged_parallel_graph(self):
        fluid.core.globals()['FLAGS_enable_parallel_graph'] = True
        if fluid.core.is_compiled_with_cuda():
            self.fetch_unmerged(use_cuda=True)
        self.fetch_unmerged(use_cuda=False)
        fluid.core.globals()['FLAGS_enable_parallel_graph'] = False


if __name__ == '__main__':
    unittest.main()