Unverified commit acfc9b8a, authored by Zeng Jinle, committed by GitHub

Reader sequential and inference partial feed (#22699)

* sequential reader stage 1, test=develop

* fix ut, test=develop

* fix iterable=False reset bug, add some logs and polish code, test=develop

* inference feed partial data, test=develop

* Turn on keep_order=True for test, test=develop

* enhance ut to test more cases, test=develop

* test commit for reverting

* Revert "test commit for reverting", test=develop

This reverts commit 80aef42e.

* add ut of merged and unmerged results, test=develop

* add more uts for coverages and add en doc of api, test=develop

* follow comments, test=develop

* change note style, test=develop
Parent 95b356a0
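To make the "inference feed partial data" part of this change concrete, below is a minimal, hypothetical usage sketch of partial feed at inference time. It assumes the fluid 1.x Python API; the tiny network, the two CPU places and the variable names are placeholders for illustration, not code from this commit.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler

# Placeholder two-device inference program.
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.data(name="image", shape=[None, 8], dtype="float32")
    pred = fluid.layers.fc(input=image, size=2)

places = fluid.cpu_places(2)
exe = fluid.Executor(places[0])
exe.run(startup_prog)
compiled = compiler.CompiledProgram(main_prog).with_data_parallel(places=places)

# Feed only one of the two devices. With partial feed supported in the
# inference executor, the idle device is skipped instead of raising a
# device/feed count mismatch error, and only the fed device's results
# are merged into the fetched tensor.
feed = [{"image": np.random.rand(3, 8).astype("float32")}]
out, = exe.run(compiled, feed=feed, fetch_list=[pred.name])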
@@ -9,7 +9,7 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(multi_devices_helper INTERFACE SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
@@ -65,13 +65,15 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
multi_devices_helper
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
reference_count_pass
eager_deletion_pass
buffer_shared_inplace_op_pass
-buffer_shared_cross_op_memory_reuse_pass)
buffer_shared_cross_op_memory_reuse_pass
set_reader_device_info_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
@@ -66,6 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
AppendMultiDevPass();
AppendSetReaderDeviceIndexPass();
AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
@@ -227,6 +228,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&strategy_);
}
void AppendSetReaderDeviceIndexPass() {
AppendPass("set_reader_device_index_pass");
}
void AppendPrintGraphPass(const std::string &pass_name,
const std::string &debug_file_suffix) {
if (!strategy_.debug_graphviz_path_.empty()) {
@@ -397,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "set_reader_device_index_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
}
VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
@@ -433,6 +441,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(set_reader_device_index_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
@@ -34,6 +34,8 @@ class ComputationOpHandle : public OpHandleBase {
OperatorBase *GetOp() { return op_.get(); }
const OperatorBase *GetOp() const { return op_.get(); }
std::string Name() const override;
const Scope *GetScope() const { return scope_; }
......
@@ -31,10 +31,12 @@ namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
-ir::Node *node, Scope *scope, const platform::Place &place,
ir::Node *node, Scope *scope, size_t scope_idx,
const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
: OpHandleBase(node),
scope_(scope),
scope_idx_(scope_idx),
place_(place),
var_infos_(vars.begin(), vars.end()),
gc_(gc) {
......
@@ -34,7 +34,7 @@ namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
-EagerDeletionOpHandle(ir::Node *node, Scope *scope,
EagerDeletionOpHandle(ir::Node *node, Scope *scope, size_t scope_idx,
const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars,
GarbageCollector *gc);
@@ -50,6 +50,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
*/
Priority GetPriority() const override { return kHighest; }
size_t GetScopeIdx() const { return scope_idx_; }
protected:
void RunImpl() override;
@@ -63,6 +65,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
void CallOnce();
Scope *scope_;
size_t scope_idx_;
platform::Place place_;
std::vector<ir::MemOptVarInfo *> var_infos_; // not own
GarbageCollector *gc_; // not own
......
@@ -12,9 +12,227 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
-namespace details {} // namespace details
namespace details {
static constexpr size_t kUndefinedDevIdx = -1UL;
// NOTE(paddle-dev): the following ops are related to multi-device
// communication. If the graph contains any of the following ops,
// it cannot separate into multiple graphs on each device.
static std::unordered_set<std::string> kMultiDeviceOps{
"sync_batch_norm",
"sync_batch_norm_grad",
"allreduce",
"c_allreduce_sum",
"c_allreduce_prod",
"c_allreduce_min",
"c_allreduce_max",
"c_allgather",
"c_reducescatter",
"c_broadcast",
"c_comm_init",
"c_comm_init_all",
"c_gen_nccl_id",
"c_sync_comm_stream",
"send",
"recv",
"send_barrier",
"fetch_barrier",
};
static size_t GetScopeIdxFromOp(const details::OpHandleBase &op) {
if (auto *compute_op =
dynamic_cast<const details::ComputationOpHandle *>(&op)) {
return kMultiDeviceOps.count(compute_op->GetOp()->Type()) == 0
? compute_op->GetScopeIdx()
: kUndefinedDevIdx;
} else if (auto *gc_op =
dynamic_cast<const details::EagerDeletionOpHandle *>(&op)) {
return gc_op->GetScopeIdx();
} else if (auto *share_op =
dynamic_cast<const details::ShareTensorBufferOpHandle *>(
&op)) {
return share_op->GetScopeIdx();
} else {
return kUndefinedDevIdx;
}
}
static bool ContainMultiDeviceOp(const ProgramDesc &program,
size_t begin_block_idx) {
for (size_t block_idx = begin_block_idx; block_idx < program.Size();
++block_idx) {
for (auto *op_desc : program.Block(block_idx).AllOps()) {
if (kMultiDeviceOps.count(op_desc->Type()) > 0) {
return true;
}
}
}
return false;
}
static size_t GetUniqueDeviceIdOfOp(const details::OpHandleBase &op) {
size_t dev_idx = GetScopeIdxFromOp(op);
if (dev_idx == kUndefinedDevIdx) {
return kUndefinedDevIdx;
}
const auto &ins = op.Inputs();
const auto &outs = op.Outputs();
auto in_outs = ins;
in_outs.insert(in_outs.end(), outs.begin(), outs.end());
for (auto *var : in_outs) {
auto *var_handle = dynamic_cast<details::VarHandle *>(var);
if (var_handle == nullptr) {
continue;
}
if (dev_idx != var_handle->scope_idx()) {
return kUndefinedDevIdx;
}
}
return dev_idx;
}
/**
* This function tries to separate the original graph into multiple graphs, in
* which each graph would only run on single device. This is usually used to
* separate a data-parallel inference graph to multiple graphs on each device.
*
* The graph can be separated into multiple single device graphs if and only if:
*
* - the graph does not contain any ops related to multi-devices communication,
* such as allreduce, send, recv, sync_batch_norm, etc.
*
* - ops on different devices do not depend on each other. That is to say, the
* graph has several disconnected sub-graphs.
*/
std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
ir::Graph *graph) {
// If sub-block contains multi-devices ops, we cannot separate
if (ContainMultiDeviceOp(graph->OriginProgram(), 1)) {
return {};
}
size_t place_num = 0;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
if (op_handles.empty()) {
return {};
}
std::unordered_map<details::OpHandleBase *, size_t> op_to_dev_idx;
for (auto &op : op_handles) {
auto dev_idx = GetUniqueDeviceIdOfOp(*op);
if (dev_idx == kUndefinedDevIdx) {
VLOG(10) << "Op " << op->Name() << " is not determined";
return {};
}
place_num = std::max(place_num, dev_idx + 1);
op_to_dev_idx[op] = dev_idx;
}
for (auto &op : op_handles) {
auto dev_idx = op_to_dev_idx.at(op);
for (auto &in_var : op->Inputs()) {
if (in_var->GeneratedOp()) {
auto iter = op_to_dev_idx.find(in_var->GeneratedOp());
if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
return {};
}
}
}
for (auto &out_var : op->Outputs()) {
for (auto &pending_op : out_var->PendingOps()) {
auto iter = op_to_dev_idx.find(pending_op);
if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
return {};
}
}
}
}
PADDLE_ENFORCE_GE(
place_num, 1,
platform::errors::NotFound(
"No place found, this may be a bug.\nIt would be helpful if you "
"could inform us of how this conversion went by opening a github "
"issue at https://github.com/PaddlePaddle/Paddle/issues/new. And "
"we will resolve it with high priority."));
std::vector<std::unique_ptr<ir::Graph>> graphs(place_num);
for (auto &g : graphs) {
g.reset(new ir::Graph(ProgramDesc()));
g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars());
}
for (auto &op : op_handles) {
auto dev_idx = op_to_dev_idx.at(op);
auto *ret_graph = graphs[dev_idx].get();
auto &ret_vars = ret_graph->Get<GraphVars>(kGraphVars)[0];
auto &ret_dummy_vars = ret_graph->Get<GraphDepVars>(kGraphDepVars);
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_idx];
ret_graph->AddNode(graph->RemoveNode(op->Node()).release());
auto handler = [&](const std::vector<VarHandleBase *> &vars) {
for (auto *var : vars) {
if (graph->Nodes().count(var->Node()) > 0) {
ret_graph->AddNode(graph->RemoveNode(var->Node()).release());
auto *dummy_var = dynamic_cast<DummyVarHandle *>(var);
if (dummy_var == nullptr) {
ret_vars.emplace(var->Name(), origin_vars.at(var->Name()));
} else {
ret_dummy_vars.emplace(dummy_var);
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
}
graph->Erase(kGraphVars);
graph->Erase(kGraphDepVars);
return graphs;
}
static bool HasDropLastReadOpImpl(const ir::Graph &graph, bool drop_last) {
auto ops = ir::FilterByNodeWrapper<OpHandleBase>(graph);
for (auto *op : ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op && compute_op->GetOp()->Type() == "read" &&
compute_op->GetOp()->Attr<bool>("drop_last") == drop_last) {
VLOG(10) << "The graph has drop_last=" << drop_last << " read op";
return true;
}
}
VLOG(10) << "The graph does not have drop_last=" << drop_last << " read op";
return false;
}
bool HasDropLastReadOp(const ir::Graph &graph) {
return HasDropLastReadOpImpl(graph, true);
}
bool HasKeepLastReadOp(const ir::Graph &graph) {
return HasDropLastReadOpImpl(graph, false);
}
} // namespace details
} // namespace framework
} // namespace paddle
@@ -47,6 +47,7 @@ constexpr char kGraphVars[] = "vars";
constexpr char kNRanks[] = "nranks";
constexpr char kPlaces[] = "places";
constexpr char kGlobalScope[] = "global_scope";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
@@ -100,6 +101,13 @@ inline std::vector<std::string> GetOpRoleVarsOrEmpty(const OpDesc &op) {
return boost::get<std::vector<std::string>>(iter->second);
}
std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
ir::Graph *graph);
bool HasDropLastReadOp(const ir::Graph &graph);
bool HasKeepLastReadOp(const ir::Graph &graph);
} // namespace details
} // namespace framework
} // namespace paddle
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
@@ -21,11 +22,11 @@ namespace paddle {
namespace framework {
namespace details {
-std::vector<std::unique_ptr<ir::Graph>>
-ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
static std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
ir::Graph *graph, size_t place_num) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
-graphs.reserve(places_.size());
graphs.reserve(place_num);
-for (size_t i = 0; i < places_.size(); ++i) {
for (size_t i = 0; i < place_num; ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
@@ -64,7 +65,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
}
}
-for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
for (size_t dev_id = 0; dev_id < place_num; ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
@@ -85,15 +86,34 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs.
: ParallelSSAGraphExecutor(strategy, local_scopes, local_exec_scopes,
places,
SeparateMultiDevicesGraph(graph,
places.size())) {}
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
-places_(std::move(places)),
-// TODO(Yancey1989): Copying graphs is not safely since it deleted the
-// attrs.
-graphs_(SeparateMultiDevicesGraph(graph)) {
places_(places),
graphs_(std::move(graphs)),
feed_status_(places.size(), FeedStatus::kNone) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
PADDLE_ENFORCE_EQ(places_.size(), graphs_.size(),
platform::errors::InvalidArgument(
"Graph number does not match place number"));
PADDLE_ENFORCE_GT(
places_.size(), 0,
platform::errors::InvalidArgument("place number must be larger than 0"));
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
@@ -123,28 +143,41 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
return result;
}
enum ExceptionStatus { kSuccess = 0, kEOF, kOther };
FetchResultType ParallelSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
size_t feed_num = std::count(feed_status_.begin(), feed_status_.end(),
FeedStatus::kHasFeed);
bool has_feed = (feed_num > 0);
VLOG(10) << "Feed num " << feed_num;
size_t place_num = places_.size();
std::vector<std::future<FetchResultType>> run_futures;
std::vector<ExceptionStatus> exception_status(place_num,
ExceptionStatus::kSuccess);
std::vector<FetchResultType> fetch_data;
-FetchResultType ret;
-fetch_data.reserve(places_.size());
-if (return_merged) {
-ret = FeedFetchList();
-} else {
-ret = FetchUnmergedList();
-}
fetch_data.reserve(place_num);
exception_holder_.Clear();
-for (size_t i = 0; i < places_.size(); ++i) {
for (size_t i = 0; i < place_num; ++i) {
-auto call = [this, i, return_merged, &fetch_tensors]() -> FetchResultType {
auto call = [&, i]() -> FetchResultType {
try {
-return executors_[i]->Run(fetch_tensors, return_merged);
if (!support_partial_feed_ || !has_feed ||
feed_status_[i] == FeedStatus::kHasFeed) {
return executors_[i]->Run(fetch_tensors, return_merged);
}
} catch (platform::EOFException &) {
exception_status[i] = ExceptionStatus::kEOF;
exception_holder_.Catch(std::current_exception());
} catch (...) {
exception_status[i] = ExceptionStatus::kOther;
exception_holder_.Catch(std::current_exception());
}
if (return_merged) {
return FeedFetchList();
} else {
@@ -161,46 +194,96 @@ FetchResultType ParallelSSAGraphExecutor::Run(
if (pool_) {
for (auto &f : run_futures) {
-if (exception_holder_.IsCaught()) {
-f.wait();
-} else {
-fetch_data.emplace_back(f.get());
fetch_data.emplace_back(f.get());
}
}
bool has_exception = exception_holder_.IsCaught();
if (!support_partial_feed_ && has_exception) {
VLOG(10) << "Exception rethrow because partial feed is not supported";
exception_holder_.ReThrow();
}
std::vector<bool> is_valid(place_num, true);
if (support_partial_feed_) {
if (has_feed) {
for (size_t i = 0; i < place_num; ++i) {
if (feed_status_[i] == FeedStatus::kNone) {
is_valid[i] = false;
} else if (exception_status[i] != ExceptionStatus::kSuccess) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Exception rethrow because non-EOF exception raises when "
"feed is given";
exception_holder_.ReThrow();
}
}
} else {
for (size_t i = 0; i < place_num; ++i) {
if (exception_status[i] == ExceptionStatus::kOther) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Exception rethrow because non-EOF exception raises when "
"feed is not given";
exception_holder_.ReThrow();
} else if (exception_status[i] != ExceptionStatus::kSuccess) {
is_valid[i] = false;
}
}
}
}
-if (exception_holder_.IsCaught()) {
if (std::count(is_valid.begin(), is_valid.end(), true) == 0) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Raise exception because there is no success worker";
exception_holder_.ReThrow();
}
if (return_merged) {
-auto &ret_val = boost::get<FeedFetchList>(ret);
FeedFetchList ret;
ret.reserve(fetch_tensors.size());
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
-lodtensor_ptrs.reserve(local_scopes_.size());
lodtensor_ptrs.reserve(place_num);
-for (size_t scope_idx = 0; scope_idx < local_scopes_.size();
-++scope_idx) {
-auto &val = boost::get<FeedFetchList>(fetch_data.at(scope_idx));
-lodtensor_ptrs.push_back(&val.at(fetch_idx));
for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
if (!is_valid[scope_idx]) {
continue;
}
const auto &fetch_list =
boost::get<FeedFetchList>(fetch_data[scope_idx]);
lodtensor_ptrs.push_back(&fetch_list[fetch_idx]);
}
-ret_val.emplace_back();
ret.emplace_back();
-ret_val.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return ret;
} else {
-auto &ret_val = boost::get<FetchUnmergedList>(ret);
FetchUnmergedList ret;
ret.reserve(fetch_tensors.size());
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
-ret_val.emplace_back();
ret.emplace_back();
for (size_t scope_idx = 0; scope_idx < local_scopes_.size();
++scope_idx) {
-auto &val = boost::get<FetchUnmergedList>(fetch_data.at(scope_idx));
if (!is_valid[scope_idx]) {
continue;
}
const auto &fetch_list =
boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
PADDLE_ENFORCE_EQ(
-val.at(fetch_idx).size(), 1,
fetch_list[fetch_idx].size(), 1,
platform::errors::Fatal(
"Each place must have only one fetched LoDTensor!"));
-ret_val.back().emplace_back(val.at(fetch_idx)[0]);
ret.back().emplace_back(fetch_list[fetch_idx][0]);
}
}
return ret;
}
-return ret;
}
} // namespace details
......
@@ -27,12 +27,25 @@ namespace framework {
namespace details {
class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public:
enum FeedStatus {
kNone = 0, // No feed
kHasFeed = 1 // Has feed
};
public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs);
~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
@@ -42,10 +55,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged) override;
-private:
-std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-ir::Graph *graph);
void SetHasFeed(size_t dev_idx, bool has_feed) {
feed_status_[dev_idx] = has_feed ? FeedStatus::kHasFeed : FeedStatus::kNone;
}
void EnablePartialFeedSupport() { support_partial_feed_ = true; }
bool SupportPartialFeed() const { return support_partial_feed_; }
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
@@ -55,6 +73,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::unique_ptr<details::FastThreadedSSAGraphExecutor>>
executors_;
ExceptionHolder exception_holder_;
bool support_partial_feed_{false};
std::vector<FeedStatus> feed_status_;
};
} // namespace details
......
@@ -228,7 +228,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
}
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
-eager_deletion_node, op->GetScope(), op->GetPlace(),
eager_deletion_node, op->GetScope(), op->GetScopeIdx(), op->GetPlace(),
std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get());
auto it = std::find_if(
......
@@ -98,7 +98,7 @@ class ReferenceCountPassTestHelper {
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_);
-ref_cnt_pass->Apply(&graph_);
ref_cnt_pass->Apply(&const_cast<ir::Graph &>(executor_->Graph()));
}
bool IsLastLivedOps(const std::string &name,
......
@@ -11,6 +11,7 @@ endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace paddle {
namespace framework {
namespace ir {
static int GetDeviceCountFromPassAttr(const Pass &pass) {
return static_cast<int>(
pass.Get<const std::vector<platform::Place>>(details::kPlaces).size());
}
static std::unordered_set<std::string> ReaderOpSet() {
return {"create_py_reader"};
}
class InitReaderDeviceCountPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
using QueueHolder =
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
auto reader_ops = ReaderOpSet();
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
const auto &scope = Get<const Scope>(details::kGlobalScope);
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto queue_name = node->Op()->Input("blocking_queue")[0];
auto var = scope.FindVar(queue_name);
if (var && var->IsType<QueueHolder>()) {
VLOG(10) << "Set device count of " << queue_name << " to be "
<< dev_cnt;
var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
}
}
}
}
};
class SetReaderDeviceIndexPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
auto reader_ops = ReaderOpSet();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
node->Wrapper<details::OpHandleBase>());
auto *op_desc = node->Op();
auto &op_base_attrs =
const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
int dev_idx = static_cast<int>(op_handle.GetScopeIdx());
op_desc->SetAttr("device_index", dev_idx);
op_desc->SetAttr("device_count", dev_cnt);
op_base_attrs["device_index"] = dev_idx;
op_base_attrs["device_count"] = dev_cnt;
++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
}
}
VLOG(10) << "Found op number " << found_op_num;
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(init_reader_device_count_pass,
paddle::framework::ir::InitReaderDeviceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalScope)
.RequirePassAttr(paddle::framework::details::kPlaces);
REGISTER_PASS(set_reader_device_index_pass,
paddle::framework::ir::SetReaderDeviceIndexPass)
.RequirePassAttr(paddle::framework::details::kPlaces);
@@ -307,18 +307,18 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const {
PADDLE_ENFORCE_GT(places.size(), 0,
platform::errors::InvalidArgument(
"place number cannot be empty when splitting"));
check_memory_size();
-int batch_size =
-lod().empty() ? dims()[0] : static_cast<int>(lod()[0].size()) - 1;
size_t batch_size =
lod().empty() ? static_cast<size_t>(dims()[0]) : lod()[0].size() - 1;
-size_t result_size = std::min(static_cast<size_t>(batch_size), places.size());
-size_t remainder = batch_size % places.size();
-std::vector<LoDTensor> results;
-results.reserve(result_size);
-// if result_size(batch_size) is 0, just return #places.size() copys of empty
// if batch_size is 0, just return #places.size() copys of empty
// tensors.
-if (result_size == 0) {
if (batch_size == 0) {
std::vector<LoDTensor> empty_results;
empty_results.reserve(places.size());
for (size_t i = 0; i < places.size(); ++i) {
LoDTensor dst;
dst.Resize(dims());
@@ -326,18 +326,22 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
if (!lod().empty()) {
dst.set_lod(lod());
}
-results.emplace_back(dst);
empty_results.emplace_back(std::move(dst));
}
-return results;
return empty_results;
}
-int step_width = static_cast<int>(batch_size / result_size);
auto step_width = (batch_size + places.size() - 1) / places.size();
auto result_size = (batch_size + step_width - 1) / step_width;
std::vector<LoDTensor> results;
results.reserve(result_size);
for (size_t i = 0; i < result_size; ++i) {
-int begin = static_cast<int>(i * step_width);
-int end = static_cast<int>((i + 1) * step_width);
-if (i + 1 == places.size()) { // last
-end += remainder;
-}
auto begin = i * step_width;
auto end = std::min<size_t>((i + 1) * step_width, batch_size);
PADDLE_ENFORCE_LT(begin, end,
platform::errors::InvalidArgument(
"begin must be less than end, this may be a bug"));
LoDTensor dst;
if (lod().empty()) {
@@ -362,7 +366,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
}
dst.set_lod(my_lod);
}
-results.emplace_back(dst);
results.emplace_back(std::move(dst));
}
return results;
......
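A small standalone illustration (not from the commit) of the new ceiling-based splitting in the SplitLoDTensor change above: with batch_size = 5 split over 4 places, step_width becomes 2 and only 3 non-empty pieces are produced, which is exactly the "fewer pieces than devices" case that the partial-feed path has to tolerate.

# Illustration only; variable names mirror the C++ code above.
batch_size, num_places = 5, 4
step_width = (batch_size + num_places - 1) // num_places   # ceil(5 / 4) = 2
result_size = (batch_size + step_width - 1) // step_width  # ceil(5 / 2) = 3
pieces = [(i * step_width, min((i + 1) * step_width, batch_size))
          for i in range(result_size)]
print(step_width, result_size, pieces)  # 2 3 [(0, 2), (2, 4), (4, 5)]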
@@ -55,8 +55,9 @@ static bool gProfileStarted = false;
class ParallelExecutorPrivate {
public:
-explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
-: places_(places) {
ParallelExecutorPrivate(const std::vector<platform::Place> &places,
Scope *global_scope)
: places_(places), global_scope_(global_scope) {
if (!FLAGS_pe_profile_fname.empty()) {
std::call_once(gProfileOnce, [] {
#ifdef WITH_GPERFTOOLS
@@ -82,6 +83,19 @@ class ParallelExecutorPrivate {
}
}
void InitReaderDeviceCount(ir::Graph *graph) const {
auto pass =
ir::PassRegistry::Instance().Get("init_reader_device_count_pass");
pass->SetNotOwned<const Scope>(details::kGlobalScope, global_scope_);
pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
&places_);
pass->Apply(graph);
}
void SetHasFeed(size_t dev_idx, bool has_feed = true);
bool AllowPartialFeed() const;
ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
@@ -257,8 +271,20 @@ class ParallelExecutorPrivate {
ir::MemOptVarInfoMapList mem_opt_var_infos_;
ir::GarbageCollectorMap gcs_;
details::ParallelSSAGraphExecutor *inference_executor_{nullptr};
};
void ParallelExecutorPrivate::SetHasFeed(size_t dev_idx, bool has_feed) {
if (inference_executor_) {
inference_executor_->SetHasFeed(dev_idx, has_feed);
}
}
bool ParallelExecutorPrivate::AllowPartialFeed() const {
return inference_executor_ && inference_executor_->SupportPartialFeed();
}
ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
if (FLAGS_use_ngraph) {
LOG_FIRST_N(WARNING, 1)
@@ -379,6 +405,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
return graph;
}
class ResetHasFeedGuard {
public:
explicit ResetHasFeedGuard(ParallelExecutorPrivate *pe_member)
: pe_member_(pe_member) {}
~ResetHasFeedGuard() {
for (size_t i = 0; i < pe_member_->places_.size(); ++i) {
pe_member_->SetHasFeed(i, false);
}
}
private:
ParallelExecutorPrivate *pe_member_;
};
size_t ParallelExecutor::DeviceCount() const { return member_->places_.size(); }
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
@@ -407,8 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
ir::Graph *graph)
-: member_(new ParallelExecutorPrivate(places)) {
-member_->global_scope_ = scope;
: member_(new ParallelExecutorPrivate(places, scope)) {
member_->InitReaderDeviceCount(graph);
member_->use_cuda_ = exec_strategy.use_cuda_;
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
@@ -616,18 +657,38 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
"Paddle should be compiled with CUDA for ParallelGraph Execution.");
#endif
} else {
-if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
-VLOG(3) << "use ThreadedSSAGraphExecutor";
-member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
bool has_drop_last_read_op = details::HasDropLastReadOp(*graph);
auto possible_inference_graphs =
details::TrySeparateToMultipleSingleDeviceGraphs(graph);
if (!possible_inference_graphs.empty()) {
VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase";
auto *pg_exe = new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
-member_->places_, graph));
member_->places_, std::move(possible_inference_graphs));
if (!has_drop_last_read_op) {
VLOG(5) << "Enable partial feed support in inference phase";
pg_exe->EnablePartialFeedSupport();
}
final_graphs = pg_exe->Graphs();
member_->executor_.reset(pg_exe);
member_->inference_executor_ = pg_exe;
} else {
-VLOG(3) << "use FastThreadedSSAGraphExecutor";
-member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
-exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
-member_->places_, graph));
LOG_IF(WARNING, details::HasKeepLastReadOp(*graph))
<< "drop_last=False for DataLoader is not supported in training "
"network. It is automatically turned to drop_last=True.";
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
}
final_graphs.emplace_back(graph);
}
-final_graphs.emplace_back(graph);
}
VLOG(3) << "use ScopeBufferedSSAGraphExecutor"; VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
...@@ -735,6 +796,8 @@ FetchResultType ParallelExecutor::Run( ...@@ -735,6 +796,8 @@ FetchResultType ParallelExecutor::Run(
platform::RecordBlock b(0); platform::RecordBlock b(0);
ResetHasFeedGuard reset_has_feed_guard(member_);
ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors, ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors,
member_->HasGarbageCollectors()); member_->HasGarbageCollectors());
...@@ -745,10 +808,31 @@ FetchResultType ParallelExecutor::Run( ...@@ -745,10 +808,31 @@ FetchResultType ParallelExecutor::Run(
void ParallelExecutor::FeedTensorsIntoLocalScopes( void ParallelExecutor::FeedTensorsIntoLocalScopes(
const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) { const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size()); if (!member_->AllowPartialFeed()) {
PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(),
platform::errors::Unimplemented(
"The feed data number %d does not match the device "
"number %d. If you are using DataLoader to feed "
"data, this may be because you set drop_last=False "
"in training network. Currently, drop_last=False for "
"DataLoader is not supported for training network. "
"Please set drop_last=True when defining DataLoader.",
tensors.size(), member_->local_scopes_.size()));
} else {
PADDLE_ENFORCE_GE(member_->local_scopes_.size(), tensors.size(),
platform::errors::InvalidArgument(
"The feed tensor number exceeds the device number"));
}
size_t feed_num = 0;
for (size_t i = 0; i < tensors.size(); ++i) {
auto &map = tensors[i];
if (map.empty()) {
continue;
}
member_->SetHasFeed(i);
++feed_num;
for (auto &pair : map) {
bool is_persistable = member_->IsPersistable(pair.first);
if (!is_persistable) {
@@ -763,11 +847,28 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
trg->set_lod(pair.second.lod());
}
}
if (!member_->AllowPartialFeed()) {
PADDLE_ENFORCE_EQ(feed_num, member_->local_scopes_.size(),
platform::errors::Unimplemented(
"The feed data number %d does not match the device "
"number %d. If you are using DataLoader to feed "
"data, this may be because you set drop_last=False "
"in training network. Currently, drop_last=False for "
"DataLoader is not supported for training network. "
"Please set drop_last=True when defining DataLoader.",
feed_num, member_->local_scopes_.size()));
}
}
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors) {
size_t num_places = member_->places_.size();
bool allow_partial_feed = member_->AllowPartialFeed();
size_t persistable_feed_len = -1UL;
size_t non_persistable_feed_len = -1UL;
for (auto &pair : tensors) {
bool is_persistable = member_->IsPersistable(pair.first);
VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable")
@@ -775,7 +876,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
<< ", place: " << pair.second.place();
auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
-if (!is_persistable && num_places != lod_tensors.size()) {
if (!is_persistable && num_places != lod_tensors.size() &&
!allow_partial_feed) {
auto error_info = string::Sprintf(
"The number(%d) of samples[%s] of current batch is less than the "
"count(%d) of devices(%s), currently, it is not allowed. ",
@@ -801,7 +903,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
framework::TensorCopy(pair.second, member_->places_.at(i), &tmp);
}
}
-if (lod_tensors.size() != num_places) {
if (lod_tensors.size() != num_places && !allow_partial_feed) {
auto error_info = string::Sprintf(
"The number(%d) of samples[%s] of the current batch does not match "
"the count(%d) of devices(%s). Because that %s is a persistable "
@@ -815,7 +917,31 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
}
-for (size_t j = 0; j < num_places; ++j) {
if (allow_partial_feed) {
if (is_persistable) {
if (persistable_feed_len == -1UL) {
persistable_feed_len = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(
persistable_feed_len, lod_tensors.size(),
platform::errors::InvalidArgument(
"The feeded number of different persistable variables "
"should be the same"));
}
} else {
if (non_persistable_feed_len == -1UL) {
non_persistable_feed_len = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(
non_persistable_feed_len, lod_tensors.size(),
platform::errors::InvalidArgument(
"The feeded number of different non-persistable variables "
"should be the same"));
}
}
}
for (size_t j = 0; j < lod_tensors.size(); ++j) {
auto *feed_scope = is_persistable ? member_->local_scopes_[j]
: member_->local_exec_scopes_[j];
auto *feed_var = feed_scope->Var(pair.first);
@@ -825,6 +951,22 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
t->set_lod(lod_tensors[j].lod());
}
}
if (allow_partial_feed && persistable_feed_len != -1UL &&
non_persistable_feed_len != -1UL) {
VLOG(10) << "Persistable len " << persistable_feed_len;
VLOG(10) << "Non persistable len " << non_persistable_feed_len;
PADDLE_ENFORCE_GE(persistable_feed_len, non_persistable_feed_len,
platform::errors::InvalidArgument(
"The feeded number of persistable variables should "
"not be less than non-persistable variables"));
}
if (non_persistable_feed_len != -1UL) {
for (size_t i = 0; i < non_persistable_feed_len; ++i) {
member_->SetHasFeed(i);
}
}
}
ParallelExecutor::~ParallelExecutor() {
@@ -875,6 +1017,10 @@ bool ParallelExecutor::EnableParallelGraphExecution(
return enable_parallel_graph;
}
const ir::Graph &ParallelExecutor::Graph() const {
return member_->executor_->Graph();
}
} // namespace framework
} // namespace paddle
@@ -882,3 +1028,4 @@ USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
USE_PASS(init_reader_device_count_pass);
@@ -80,6 +80,8 @@ class ParallelExecutor {
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged = true);
const ir::Graph &Graph() const;
private:
// broadcast the parameters from the 0th device.
// trainer_id the trainer index in nccl distributed training.
......
@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
~DecoratedReader();
const std::shared_ptr<ReaderBase>& UnderlyingReader() const {
return reader_;
}
protected:
void ShutdownImpl() override {
VLOG(1) << "ShutdownImpl";
@@ -190,6 +194,8 @@ class ReaderHolder {
return reader_->NeedCheckFeed();
}
void Clear() { reader_.reset(); }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
......
@@ -56,6 +56,7 @@ class CudnnRNNCache;
namespace reader {
class LoDTensorBlockingQueueHolder;
class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
} // namespace reader
} // namespace operators
@@ -139,6 +140,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
operators::reader::LoDTensorBlockingQueueHolder,
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_NCCL)
ncclUniqueId, platform::Communicator, platform::NCCLCommunicator,
......
@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
-if (out->Get() != nullptr) {
-return;
-}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
if (out->Get() != nullptr) {
auto* decorated_reader =
dynamic_cast<framework::DecoratedReader*>(out->Get().get());
PADDLE_ENFORCE_NOT_NULL(
decorated_reader,
platform::errors::NotFound("Not inited with DecoratedReader"));
if (decorated_reader->UnderlyingReader() == underlying_reader.Get()) {
return;
}
}
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "AUTO") {
@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
VLOG(10) << "Create new double buffer reader on " << place;
out->Reset(framework::MakeDecoratedReader<BufferedReader>(underlying_reader,
place, 2));
}
......
@@ -38,8 +38,21 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name);
-auto* queue_holder =
-queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
std::shared_ptr<LoDTensorBlockingQueue> queue;
std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> ordered_queue;
int dev_idx = -1;
if (queue_holder_var->IsType<LoDTensorBlockingQueueHolder>()) {
queue = queue_holder_var->Get<LoDTensorBlockingQueueHolder>().GetQueue();
} else if (queue_holder_var
->IsType<OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
auto* queue_holder =
queue_holder_var
->GetMutable<OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
dev_idx = Attr<int>("device_index");
ordered_queue = queue_holder->GetQueue();
ordered_queue->SetDeviceCount(Attr<int>("device_count"));
queue = ordered_queue->GetQueue(dev_idx);
}
/* Converting shape_concat and ranks into DDim of each data. /* Converting shape_concat and ranks into DDim of each data.
shape_concat and ranks are shapes and shape ranks of each data. E.g. shape_concat and ranks are shapes and shape ranks of each data. E.g.
...@@ -71,8 +84,12 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -71,8 +84,12 @@ class CreatePyReaderOp : public framework::OperatorBase {
for (size_t i = 0; i < need_check_feed_int.size(); ++i) { for (size_t i = 0; i < need_check_feed_int.size(); ++i) {
need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i])); need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i]));
} }
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue(), dims, auto py_reader =
var_types, need_check_feed)); std::make_shared<PyReader>(queue, dims, var_types, need_check_feed);
if (ordered_queue) {
ordered_queue->SetResetMethod(dev_idx, [out] { out->Clear(); });
}
out->Reset(py_reader);
} }
}; };
...@@ -82,6 +99,13 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase { ...@@ -82,6 +99,13 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
AddInput("blocking_queue", AddInput("blocking_queue",
"Name of the `LoDTensorBlockingQueueHolder` variable"); "Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("device_index", "The device index this reader offers data")
.SetDefault(0);
AddAttr<int>("device_count",
"The total device number this reader offers data")
.SetDefault(1);
AddComment(R"DOC( AddComment(R"DOC(
Create PyReader to support LoDTensor data feeding in Python side. Create PyReader to support LoDTensor data feeding in Python side.
)DOC"); )DOC");
......
...@@ -27,16 +27,13 @@ namespace paddle { ...@@ -27,16 +27,13 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class LoDTensorBlockingQueueHolder;
class LoDTensorBlockingQueue { class LoDTensorBlockingQueue {
friend class LoDTensorBlockingQueueHolder; public:
private:
explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false) explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
: queue_(capacity, speed_test_mode) {} : queue_(capacity, speed_test_mode) {}
public: ~LoDTensorBlockingQueue() { VLOG(10) << "Destruct LoDTensorBlockingQueue"; }
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) { bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return queue_.Send(lod_tensor_vec); return queue_.Send(lod_tensor_vec);
} }
...@@ -67,10 +64,140 @@ class LoDTensorBlockingQueue { ...@@ -67,10 +64,140 @@ class LoDTensorBlockingQueue {
inline void Kill() { queue_.Kill(); } inline void Kill() { queue_.Kill(); }
inline bool WaitForInited(size_t) { return true; }
private: private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_; BlockingQueue<std::vector<framework::LoDTensor>> queue_;
}; };
class OrderedMultiDeviceLoDTensorBlockingQueue {
public:
OrderedMultiDeviceLoDTensorBlockingQueue(size_t capacity,
bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode) {}
~OrderedMultiDeviceLoDTensorBlockingQueue() {
VLOG(10) << "Destruct OrderedMultiDeviceLoDTensorBlockingQueue";
}
bool WaitForInited(size_t milliseconds) {
std::unique_lock<std::mutex> lock(init_mutex_);
return cv_.wait_for(lock, std::chrono::milliseconds(milliseconds),
[this] { return !queues_.empty(); });
}
void SetDeviceCount(size_t dev_cnt) {
{
std::lock_guard<std::mutex> lock(init_mutex_);
PADDLE_ENFORCE_GE(dev_cnt, 1,
platform::errors::InvalidArgument(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"));
if (!queues_.empty()) {
PADDLE_ENFORCE_EQ(queues_.size(), dev_cnt,
platform::errors::InvalidArgument(
"queues should be only inited once"));
return;
}
VLOG(1) << "Init queue with size " << dev_cnt;
queues_.resize(dev_cnt);
for (auto& item : queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
}
cv_.notify_all();
}
const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue(size_t idx) const {
EnforceIsInited();
PADDLE_ENFORCE_LT(
idx, queues_.size(),
platform::errors::OutOfRange("The queue index is out of range"));
return queues_[idx];
}
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return CurQueue()->Push(lod_tensor_vec);
}
inline size_t Size() const {
size_t size = 0;
for (auto& item : queues_) {
size += item->Size();
}
return size;
}
inline void Close() {
for (auto& item : queues_) {
item->Close();
}
}
inline void Kill() {
for (auto& item : queues_) {
item->Kill();
}
}
inline void Reset() {
{
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
for (auto& method : reset_methods_) {
if (method) method();
}
}
auto dev_cnt = queues_.size();
for (auto& item : queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
data_index_ = 0;
}
inline void SetResetMethod(size_t idx,
const std::function<void()>& reset_method) {
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
EnforceIsInited();
if (reset_methods_.size() <= idx) {
reset_methods_.resize(idx + 1);
}
reset_methods_[idx] = reset_method;
}
inline size_t Cap() const { return capacity_; }
private:
const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() {
return queues_[(data_index_++) % queues_.size()];
}
private:
void EnforceIsInited() const {
PADDLE_ENFORCE_EQ(queues_.empty(), false,
platform::errors::NotFound("queue has not been inited"));
}
private:
std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
mutable uint64_t data_index_{0};
size_t dev_cnt_{0};
const size_t capacity_;
const bool speed_test_mode_;
bool is_closed_{false};
std::vector<std::function<void()>> reset_methods_;
mutable std::mutex reset_mutex_;
mutable std::mutex init_mutex_;
mutable std::condition_variable cv_;
};
class LoDTensorBlockingQueueHolder { class LoDTensorBlockingQueueHolder {
public: public:
void InitOnce(size_t capacity, bool speed_test_mode = false) { void InitOnce(size_t capacity, bool speed_test_mode = false) {
...@@ -88,6 +215,26 @@ class LoDTensorBlockingQueueHolder { ...@@ -88,6 +215,26 @@ class LoDTensorBlockingQueueHolder {
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
}; };
class OrderedMultiDeviceLoDTensorBlockingQueueHolder {
public:
void InitOnce(size_t capacity, bool speed_test_mode = false) {
PADDLE_ENFORCE_EQ(queue_, nullptr,
platform::errors::AlreadyExists(
"OrderedMultiDeviceLoDTensorBlockingQueueHolder::"
"InitOnce() can only be called once"));
queue_.reset(new OrderedMultiDeviceLoDTensorBlockingQueue(capacity,
speed_test_mode));
}
inline const std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue>&
GetQueue() const {
return queue_;
}
private:
std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> queue_;
};
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" and it is set by ParallelExecutor instance, not users.") " and it is set by ParallelExecutor instance, not users.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>("infer_out", "").SetDefault(true); AddAttr<bool>("infer_out", "").SetDefault(true);
AddAttr<bool>("drop_last",
"Whether to drop last batches whose number is less than CPU "
"cores/GPU cards number")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Read Operator Read Operator
......
...@@ -51,7 +51,6 @@ limitations under the License. */ ...@@ -51,7 +51,6 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
...@@ -94,9 +93,6 @@ limitations under the License. */ ...@@ -94,9 +93,6 @@ limitations under the License. */
#include "pybind11/stl.h" #include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
DECLARE_bool(use_ngraph); DECLARE_bool(use_ngraph);
...@@ -997,35 +993,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -997,35 +993,6 @@ All parameter, weight, gradient are variables in Paddle.
BindReader(&m); BindReader(&m);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
using LoDTensorBlockingQueueHolder =
::paddle::operators::reader::LoDTensorBlockingQueueHolder;
py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
pybind11::gil_scoped_release release;
return self.Push(lod_tensor_vec);
})
.def("size", &LoDTensorBlockingQueue::Size)
.def("capacity", &LoDTensorBlockingQueue::Cap)
.def("close", &LoDTensorBlockingQueue::Close)
.def("kill", &LoDTensorBlockingQueue::Kill)
.def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
[](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
VLOG(1) << "init_lod_tensor_blocking_queue";
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
py::return_value_policy::copy);
py::class_<Scope>(m, "_Scope", R"DOC( py::class_<Scope>(m, "_Scope", R"DOC(
Scope is an association of a name to Variable. All variables belong to Scope. Scope is an association of a name to Variable. All variables belong to Scope.
......
...@@ -20,47 +20,141 @@ ...@@ -20,47 +20,141 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "Python.h" #include "Python.h"
#include "boost/optional.hpp"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
namespace py = pybind11; namespace py = pybind11;
namespace reader = operators::reader;
// Check whether the tensor shape matches the VarDesc shape
// Return the different shape if exists
static boost::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
const framework::LoDTensor &tensor, const framework::VarDesc &var_desc,
size_t num_places) {
auto tensor_shape = tensor.dims();
auto desc_shape = var_desc.GetShape();
int64_t rank = tensor_shape.size();
if (UNLIKELY(rank == 0)) {
if (desc_shape.size() != 0) { // Tensor rank = 0 but desc does not match
return framework::vectorize<int64_t>(tensor_shape);
} else {
return boost::none;
}
}
PADDLE_ENFORCE_GE(tensor_shape[0], 0,
platform::errors::InvalidArgument(
"Tensor shape at dim 0 must not be less than 0"));
if (!tensor.lod().empty()) {
tensor_shape[0] = -1; // unknown shape
} else {
int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
tensor_shape[0] = split_size;
if (desc_shape[0] >= 0) { // need check dim 0
if (tensor_shape[0] != desc_shape[0]) {
return framework::vectorize<int64_t>(tensor_shape);
}
if (remainder > 0) {
tensor_shape[0] = remainder;
return framework::vectorize<int64_t>(tensor_shape);
}
}
}
for (int64_t idx = 1; idx < rank; ++idx) {
PADDLE_ENFORCE_GE(
tensor_shape[idx], 0,
platform::errors::InvalidArgument(
"Tensor shape at dim %d must not be less than 0", idx));
if (desc_shape[idx] >= 0 && tensor_shape[idx] != desc_shape[idx]) {
return framework::vectorize<int64_t>(tensor_shape);
}
}
return boost::none;
}
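Note: DiffTensorShapeWithVarDesc is exposed to Python as core.diff_tensor_shape further below and replaces the manual shape arithmetic in executor.py. A rough pure-Python sketch of its dim-0 rule, for illustration only (the LoD and rank-0 branches are omitted and the function name diff_dim0_per_device is made up):

import math

# Illustration only: mirror of the dim-0 split rule in DiffTensorShapeWithVarDesc.
def diff_dim0_per_device(tensor_shape, desc_shape, num_places):
    batch = tensor_shape[0]
    # Same ceil-division as the C++ code: (batch + num_places - 1) / num_places
    split = (batch + num_places - 1) // num_places
    remainder = batch % split if split != 0 else 0
    if desc_shape[0] >= 0:
        if split != desc_shape[0]:
            return [split] + list(tensor_shape[1:])      # per-device batch mismatch
        if remainder > 0:
            return [remainder] + list(tensor_shape[1:])  # last device gets fewer rows
    return None  # shapes are compatible

# E.g. a batch of 10 split over 4 places gives per-device batches of 3/3/3/1,
# so a VarDesc that pins dim 0 to 3 still reports the remainder as a mismatch.
print(diff_dim0_per_device([10, 784], [3, 784], 4))  # [1, 784]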
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
return queue;
}
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
&queue,
size_t idx) {
return queue->GetQueue(idx);
}
template <typename QueueType>
class MultiDeviceFeedReader { class MultiDeviceFeedReader {
public: public:
using ResultDictList = using ResultDictList =
std::vector<std::unordered_map<std::string, framework::LoDTensor>>; std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
using ResultList = std::vector<std::vector<framework::LoDTensor>>; using ResultList = std::vector<std::vector<framework::LoDTensor>>;
static constexpr bool kKeepOrder =
std::is_same<QueueType,
reader::OrderedMultiDeviceLoDTensorBlockingQueue>::value;
MultiDeviceFeedReader( MultiDeviceFeedReader(
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue, const std::shared_ptr<QueueType> &queue,
const std::vector<std::string> &names, const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes, const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes, const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed, const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer) const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last)
: queue_(queue), : queue_(queue),
names_(names), names_(names),
pool_(new ::ThreadPool(dst_places.size())) { pool_(new ::ThreadPool(dst_places.size())),
drop_last_(drop_last) {
std::vector<framework::DDim> dims; std::vector<framework::DDim> dims;
for (auto &shape : shapes) { for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape)); dims.push_back(framework::make_ddim(shape));
} }
std::shared_ptr<framework::ReaderBase> reader(
new operators::reader::PyReader(queue, dims, dtypes, need_check_feed)); auto first_reader = std::make_shared<reader::PyReader>(
GetQueue(queue, 0), dims, dtypes, need_check_feed);
auto create_or_get_reader = [&](size_t idx) {
if (idx == 0 ||
std::is_same<QueueType, reader::LoDTensorBlockingQueue>::value) {
return first_reader;
} else {
return std::make_shared<reader::PyReader>(GetQueue(queue, idx), dims,
dtypes, need_check_feed);
}
};
readers_.reserve(dst_places.size()); readers_.reserve(dst_places.size());
for (auto &p : dst_places) { for (size_t i = 0; i < dst_places.size(); ++i) {
auto &p = dst_places[i];
auto *holder = new framework::ReaderHolder(); auto *holder = new framework::ReaderHolder();
auto reader = create_or_get_reader(i);
if (use_double_buffer) { if (use_double_buffer) {
VLOG(10) << "Creating " << i << "-th BufferedReader";
holder->Reset( holder->Reset(
framework::MakeDecoratedReader<operators::reader::BufferedReader>( framework::MakeDecoratedReader<operators::reader::BufferedReader>(
reader, p, 2)); reader, p, 2));
...@@ -80,12 +174,22 @@ class MultiDeviceFeedReader { ...@@ -80,12 +174,22 @@ class MultiDeviceFeedReader {
ReadAsync(); ReadAsync();
} }
bool DropLast() const { return drop_last_; }
ResultDictList ReadNext() { ResultDictList ReadNext() {
CheckNextStatus(); CheckNextStatus();
ResultDictList result(ret_.size()); ResultDictList result;
result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) { for (size_t i = 0; i < ret_.size(); ++i) {
if (ret_[i].empty()) {
if (!kKeepOrder) result.emplace_back();
continue;
}
result.emplace_back();
auto &ret = result.back();
for (size_t j = 0; j < names_.size(); ++j) { for (size_t j = 0; j < names_.size(); ++j) {
result[i].emplace(names_[j], std::move(ret_[i][j])); ret.emplace(names_[j], std::move(ret_[i][j]));
} }
} }
ReadAsync(); ReadAsync();
...@@ -97,6 +201,7 @@ class MultiDeviceFeedReader { ...@@ -97,6 +201,7 @@ class MultiDeviceFeedReader {
ResultList result; ResultList result;
result.reserve(ret_.size()); result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) { for (size_t i = 0; i < ret_.size(); ++i) {
if (kKeepOrder && ret_[i].empty()) continue;
result.emplace_back(std::move(ret_[i])); result.emplace_back(std::move(ret_[i]));
} }
ReadAsync(); ReadAsync();
...@@ -122,24 +227,29 @@ class MultiDeviceFeedReader { ...@@ -122,24 +227,29 @@ class MultiDeviceFeedReader {
}; };
Status WaitFutures(std::exception_ptr *excep) { Status WaitFutures(std::exception_ptr *excep) {
bool is_success = true;
*excep = nullptr; *excep = nullptr;
size_t success_num = 0;
for (size_t i = 0; i < futures_.size(); ++i) { for (size_t i = 0; i < futures_.size(); ++i) {
auto each_status = futures_[i].get(); auto each_status = futures_[i].get();
if (UNLIKELY(each_status != Status::kSuccess)) { if (UNLIKELY(each_status != Status::kSuccess)) {
is_success = false;
if (UNLIKELY(each_status == Status::kException)) { if (UNLIKELY(each_status == Status::kException)) {
PADDLE_ENFORCE_NOT_NULL(exceptions_[i]); PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
*excep = exceptions_[i]; *excep = exceptions_[i];
exceptions_[i] = nullptr; exceptions_[i] = nullptr;
} }
} else {
++success_num;
} }
} }
if (UNLIKELY(*excep)) { if (UNLIKELY(*excep)) {
return Status::kException; return Status::kException;
}
if (drop_last_) {
return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
} else { } else {
return is_success ? Status::kSuccess : Status::kEOF; return success_num > 0 ? Status::kSuccess : Status::kEOF;
} }
} }
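Note: the drop_last flag changes how WaitFutures() combines the per-device read statuses. A small illustrative Python sketch of just that rule (the exception path is omitted and the names are hypothetical):

# Illustration only: how the per-device statuses are aggregated.
def aggregate_status(per_device_ok, drop_last):
    success_num = sum(1 for ok in per_device_ok if ok)
    if drop_last:
        # every device must have data, otherwise the whole step is EOF
        return "success" if success_num == len(per_device_ok) else "eof"
    # partial feed: one device with data is enough to keep reading
    return "success" if success_num > 0 else "eof"

assert aggregate_status([True, True], drop_last=True) == "success"
assert aggregate_status([True, False], drop_last=True) == "eof"
assert aggregate_status([True, False], drop_last=False) == "success"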
...@@ -183,7 +293,7 @@ class MultiDeviceFeedReader { ...@@ -183,7 +293,7 @@ class MultiDeviceFeedReader {
PADDLE_ENFORCE_EQ(status, Status::kSuccess); PADDLE_ENFORCE_EQ(status, Status::kSuccess);
} }
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_; std::shared_ptr<QueueType> queue_;
std::vector<std::string> names_; std::vector<std::string> names_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
...@@ -193,24 +303,21 @@ class MultiDeviceFeedReader { ...@@ -193,24 +303,21 @@ class MultiDeviceFeedReader {
std::vector<std::exception_ptr> exceptions_; std::vector<std::exception_ptr> exceptions_;
std::vector<std::vector<framework::LoDTensor>> ret_; std::vector<std::vector<framework::LoDTensor>> ret_;
bool drop_last_;
}; };
void BindReader(py::module *module) { template <typename QueueType>
void BindMultiDeviceReader(py::module *module, const char *reader_name) {
auto &m = *module; auto &m = *module;
namespace reader = ::paddle::operators::reader; using ReaderType = MultiDeviceFeedReader<QueueType>;
py::class_<ReaderType>(m, reader_name, "")
py::class_<framework::ReaderHolder>(m, "Reader", "") .def("read_next", &ReaderType::ReadNext,
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
.def("read_next", &MultiDeviceFeedReader::ReadNext,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("read_next_list", &MultiDeviceFeedReader::ReadNextList, .def("read_next_list", &ReaderType::ReadNextList,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("read_next_var_list", .def("read_next_var_list",
[](MultiDeviceFeedReader &self) { [](ReaderType &self) {
auto result_list = self.ReadNextList(); auto result_list = self.ReadNextList();
auto &tensor_list = result_list[0]; auto &tensor_list = result_list[0];
std::vector<std::shared_ptr<imperative::VarBase>> var_list; std::vector<std::shared_ptr<imperative::VarBase>> var_list;
...@@ -234,23 +341,116 @@ void BindReader(py::module *module) { ...@@ -234,23 +341,116 @@ void BindReader(py::module *module) {
return var_list; return var_list;
}, },
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("reset", &MultiDeviceFeedReader::Reset, .def("reset", &ReaderType::Reset,
py::call_guard<py::gil_scoped_release>());
}
void BindReader(py::module *module) {
auto &m = *module;
m.def("diff_tensor_shape", [](const framework::LoDTensor &tensor,
const framework::VarDesc &var_desc,
size_t num_places) -> py::object {
auto diff = DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
if (diff) {
return py::cast(std::move(diff.get()));
} else {
return py::cast(nullptr);
}
});
m.def("init_lod_tensor_blocking_queue",
[](framework::Variable &var, size_t capacity,
bool is_ordered) -> py::object {
VLOG(1) << "init_lod_tensor_blocking_queue";
if (is_ordered) {
auto *holder = var.GetMutable<
reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return py::cast(holder->GetQueue());
} else {
auto *holder =
var.GetMutable<reader::LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return py::cast(holder->GetQueue());
}
},
py::return_value_policy::copy);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<reader::LoDTensorBlockingQueue,
std::shared_ptr<reader::LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](reader::LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
return self.Push(lod_tensor_vec);
},
py::call_guard<py::gil_scoped_release>())
.def("size", &reader::LoDTensorBlockingQueue::Size)
.def("capacity", &reader::LoDTensorBlockingQueue::Cap)
.def("close", &reader::LoDTensorBlockingQueue::Close)
.def("kill", &reader::LoDTensorBlockingQueue::Kill)
.def("wait_for_inited", &reader::LoDTensorBlockingQueue::WaitForInited,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
py::class_<reader::OrderedMultiDeviceLoDTensorBlockingQueue,
std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>>(
m, "OrderedMultiDeviceLoDTensorBlockingQueue", "")
.def("push",
[](reader::OrderedMultiDeviceLoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
return self.Push(lod_tensor_vec);
},
py::call_guard<py::gil_scoped_release>())
.def("size", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Size)
.def("capacity", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Cap)
.def("close", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Close)
.def("kill", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Kill)
.def("wait_for_inited",
&reader::OrderedMultiDeviceLoDTensorBlockingQueue::WaitForInited,
py::call_guard<py::gil_scoped_release>())
.def("reset", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Reset);
BindMultiDeviceReader<reader::LoDTensorBlockingQueue>(
module, "MultiDeviceFeedReader");
BindMultiDeviceReader<reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
module, "OrderedMultiDeviceFeedReader");
m.def("create_py_reader", m.def("create_py_reader",
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> [](const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue,
&queue,
const std::vector<std::string> &names, const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes, const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes, const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed, const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, const std::vector<platform::Place> &dst_places,
bool use_double_buffer) { bool use_double_buffer, bool drop_last) {
return new MultiDeviceFeedReader(queue, names, shapes, dtypes, return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
need_check_feed, dst_places, queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer); use_double_buffer, drop_last);
}, },
py::return_value_policy::take_ownership); py::return_value_policy::take_ownership);
m.def(
"create_py_reader",
[](const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
&queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last) {
queue->SetDeviceCount(dst_places.size());
return new MultiDeviceFeedReader<
reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer, drop_last);
},
py::return_value_policy::take_ownership);
} }
} // namespace pybind } // namespace pybind
......
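Note: the two queue flavours registered above can be exercised directly from Python. A minimal sketch, assuming a CPU build (the LoDTensorArray/set pattern follows what GeneratorLoader does further below):

import numpy as np
from paddle.fluid import core

# The third argument of init_lod_tensor_blocking_queue selects the queue flavour.
plain_queue = core.init_lod_tensor_blocking_queue(core.Variable(), 8, False)
ordered_queue = core.init_lod_tensor_blocking_queue(core.Variable(), 8, True)

t = core.LoDTensor()
t.set(np.ones([2, 3], dtype='float32'), core.CPUPlace())

batch = core.LoDTensorArray()
batch.append(t)
plain_queue.push(batch)          # single shared queue, same behaviour as before
print(plain_queue.size())        # 1

# The ordered queue must learn the device count before data is pushed;
# create_py_reader does this via SetDeviceCount(len(dst_places)).
print(ordered_queue.capacity())  # 8
ordered_queue.close()
plain_queue.close()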
...@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1): ...@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1):
the feed value the feed value
""" """
if var.desc.need_check_feed(): if var.desc.need_check_feed():
feed_shape = feed.shape() diff_shape = core.diff_tensor_shape(feed, var.desc, num_places)
if six.PY2: if diff_shape is not None:
feed_shape[0] = long(feed_shape[0] /
num_places) if len(feed.lod()) == 0 else -1
else:
feed_shape[0] = int(feed_shape[0] /
num_places) if len(feed.lod()) == 0 else -1
if not dimension_is_compatible_with(feed_shape, var.shape):
raise ValueError( raise ValueError(
'The fed Variable %r should have dimensions = %d, shape = ' 'The fed Variable %r should have dimensions = %d, shape = '
'%r, but received fed shape %r on each device' % '%r, but received fed shape %r on each device' %
(var.name, len(var.shape), var.shape, feed_shape)) (var.name, len(var.shape), var.shape, diff_shape))
if not dtype_is_compatible_with(feed._dtype(), var.dtype): if not dtype_is_compatible_with(feed._dtype(), var.dtype):
var_dtype_format = convert_dtype(var.dtype) if isinstance( var_dtype_format = convert_dtype(var.dtype) if isinstance(
var.dtype, core.VarDesc.VarType) else var.dtype var.dtype, core.VarDesc.VarType) else var.dtype
...@@ -646,11 +640,6 @@ class Executor(object): ...@@ -646,11 +640,6 @@ class Executor(object):
exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple): elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(program._places):
raise ValueError(
"Feed a list of tensor, the list should be the same size as places"
)
res = list() res = list()
for i, each in enumerate(feed): for i, each in enumerate(feed):
if not isinstance(each, dict): if not isinstance(each, dict):
......
...@@ -429,7 +429,7 @@ def _py_reader(capacity, ...@@ -429,7 +429,7 @@ def _py_reader(capacity,
double_buffer_name = "_".join([name, "double_buffer"]) double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, False)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
......
...@@ -48,6 +48,17 @@ __all__ = ['PyReader', 'DataLoader'] ...@@ -48,6 +48,17 @@ __all__ = ['PyReader', 'DataLoader']
data_loader_unique_name_generator = UniqueNameGenerator() data_loader_unique_name_generator = UniqueNameGenerator()
KEEP_DATA_LOADER_ORDER = True
def keep_data_loader_order(*args):
global KEEP_DATA_LOADER_ORDER
if len(args) == 0:
return KEEP_DATA_LOADER_ORDER
else:
assert len(args) == 1 and isinstance(args[0], bool)
KEEP_DATA_LOADER_ORDER = args[0]
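Note: keep_data_loader_order is a module-level switch read once in GeneratorLoader.__init__; it decides whether the loader is backed by the ordered multi-device queue. A minimal usage sketch (the second new unit test below toggles it the same way):

from paddle.fluid.reader import keep_data_loader_order

# Query the current setting (defaults to True).
assert keep_data_loader_order() is True

# Turn strict ordering off *before* constructing any GeneratorLoader,
# since the flag is captured when the loader is created.
keep_data_loader_order(False)
assert keep_data_loader_order() is False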
def _convert_places(places): def _convert_places(places):
if not isinstance(places, (list, tuple)): if not isinstance(places, (list, tuple)):
...@@ -172,8 +183,12 @@ class DataLoader(object): ...@@ -172,8 +183,12 @@ class DataLoader(object):
use_double_buffer=True, use_double_buffer=True,
iterable=True, iterable=True,
return_list=False, return_list=False,
use_multiprocess=False): use_multiprocess=False,
drop_last=True):
""" """
.. note::
**The framework ensures that the data loading order of DataLoader is exactly the same as the user-defined data source.**
Create a DataLoader object for loading data from Python generator. Create a DataLoader object for loading data from Python generator.
Data would be prefetched using Python thread and be pushed Data would be prefetched using Python thread and be pushed
into a queue asynchronously. into a queue asynchronously.
...@@ -182,7 +197,7 @@ class DataLoader(object): ...@@ -182,7 +197,7 @@ class DataLoader(object):
:code:`set_sample_generator` , :code:`set_sample_list_generator` and :code:`set_sample_generator` , :code:`set_sample_list_generator` and
:code:`set_batch_generator` . Please see the following example codes :code:`set_batch_generator` . Please see the following example codes
to know their usages. to know their usages.
If iterable = True, the created DataLoader object is a Python generator If iterable = True, the created DataLoader object is a Python generator
object, which is iterable using for-range loop. object, which is iterable using for-range loop.
...@@ -218,11 +233,18 @@ class DataLoader(object): ...@@ -218,11 +233,18 @@ class DataLoader(object):
can be used in the dygraph mode. In the static graph mode, can be used in the dygraph mode. In the static graph mode,
whether this parameter is set or not has no effect. whether this parameter is set or not has no effect.
The Default value is False. The Default value is False.
drop_last (bool): whether to drop the last batches whose number is
less than the CPU core/GPU card number. The default value is
True. In the training phase, users should not set drop_last=False,
because all CPU cores/GPU cards must read data from the DataLoader.
In the inference phase, users can set drop_last=False, so that the
last batches whose number is less than the CPU core/GPU card
number can still be tested.
Returns: Returns:
loader (DataLoader): the created DataLoader object. loader (DataLoader): the created DataLoader object.
Examples: Examples 1:
.. code-block:: python .. code-block:: python
...@@ -354,6 +376,49 @@ class DataLoader(object): ...@@ -354,6 +376,49 @@ class DataLoader(object):
assert image.shape == [BATCH_SIZE, 784] assert image.shape == [BATCH_SIZE, 784]
assert label.shape == [BATCH_SIZE, 1] assert label.shape == [BATCH_SIZE, 1]
assert relu.shape == [BATCH_SIZE, 784] assert relu.shape == [BATCH_SIZE, 784]
Examples 2:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import os
# We use 2 CPU cores to run inference network
os.environ['CPU_NUM'] = '2'
# The data source has only 3 batches, which cannot be
# divided evenly among the CPU cores
def batch_generator():
for i in range(3):
yield np.array([i+1]).astype('float32'),
x = fluid.data(name='x', shape=[None], dtype='float32')
y = x * x
def run_inference(drop_last):
loader = fluid.io.DataLoader.from_generator(feed_list=[x],
capacity=8, drop_last=drop_last)
loader.set_batch_generator(batch_generator, fluid.cpu_places())
exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.CompiledProgram(fluid.default_main_program())
prog = prog.with_data_parallel()
result = []
for data in loader():
each_ret, = exe.run(prog, feed=data, fetch_list=[y])
result.extend(each_ret)
return result
# Set drop_last to True, so that the last batch whose
# number is less than the CPU core number is discarded.
print(run_inference(drop_last=True)) # [1.0, 4.0]
# Set drop_last to False, so that the last batch whose
# number is less than the CPU core number can be tested.
print(run_inference(drop_last=False)) # [1.0, 4.0, 9.0]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return DygraphGeneratorLoader(feed_list, capacity, return DygraphGeneratorLoader(feed_list, capacity,
...@@ -361,7 +426,7 @@ class DataLoader(object): ...@@ -361,7 +426,7 @@ class DataLoader(object):
return_list, use_multiprocess) return_list, use_multiprocess)
else: else:
return GeneratorLoader(feed_list, capacity, use_double_buffer, return GeneratorLoader(feed_list, capacity, use_double_buffer,
iterable, return_list) iterable, return_list, drop_last)
@staticmethod @staticmethod
def from_dataset(dataset, places, drop_last=True): def from_dataset(dataset, places, drop_last=True):
...@@ -514,10 +579,10 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -514,10 +579,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._dtypes = [] self._dtypes = []
self._need_check_feed = [] self._need_check_feed = []
self._blocking_queue = core.init_lod_tensor_blocking_queue( self._blocking_queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._capacity) core.Variable(), self._capacity, False)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes, self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer) self._need_check_feed, self._places, self._use_double_buffer, True)
def _start(self): def _start(self):
if self._use_multiprocess: if self._use_multiprocess:
...@@ -728,12 +793,16 @@ class GeneratorLoader(DataLoaderBase): ...@@ -728,12 +793,16 @@ class GeneratorLoader(DataLoaderBase):
capacity=None, capacity=None,
use_double_buffer=True, use_double_buffer=True,
iterable=True, iterable=True,
return_list=False): return_list=False,
drop_last=True):
self._tensor_reader = None self._tensor_reader = None
self._places = None self._places = None
self._thread = None self._thread = None
self._queue = None self._queue = None
self._feed_list = feed_list self._feed_list = feed_list
self._exited = False
self._drop_last = drop_last
self._keep_order = keep_data_loader_order()
if not capacity: if not capacity:
raise ValueError("Please give value to capacity.") raise ValueError("Please give value to capacity.")
self._iterable = iterable self._iterable = iterable
...@@ -761,11 +830,12 @@ class GeneratorLoader(DataLoaderBase): ...@@ -761,11 +830,12 @@ class GeneratorLoader(DataLoaderBase):
self._need_check_feed = [ self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list v.desc.need_check_feed() for v in self._feed_list
] ]
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(), self._queue = core.init_lod_tensor_blocking_queue(
self._capacity) core.Variable(), self._capacity, self._keep_order)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes, self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer) self._need_check_feed, self._places, self._use_double_buffer,
self._drop_last)
def _init_non_iterable(self): def _init_non_iterable(self):
lod_levels = [] lod_levels = []
...@@ -789,16 +859,21 @@ class GeneratorLoader(DataLoaderBase): ...@@ -789,16 +859,21 @@ class GeneratorLoader(DataLoaderBase):
double_buffer_name = data_loader_unique_name_generator('double_buffer') double_buffer_name = data_loader_unique_name_generator('double_buffer')
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity) self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity,
self._keep_order)
if self._keep_order:
block = default_main_program().current_block()
else:
block = default_startup_program().current_block()
startup_blk = default_startup_program().current_block() reader_var = block.create_var(name=reader_name)
startup_var = startup_blk.create_var(name=reader_name)
dtype_int = [int(t) for t in dtypes] dtype_int = [int(t) for t in dtypes]
startup_blk.append_op( block.append_op(
type='create_py_reader', type='create_py_reader',
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]}, outputs={'Out': [reader_var]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
...@@ -807,16 +882,23 @@ class GeneratorLoader(DataLoaderBase): ...@@ -807,16 +882,23 @@ class GeneratorLoader(DataLoaderBase):
'ranks': ranks 'ranks': ranks
}) })
startup_var.desc.set_dtypes(dtypes) reader_var.desc.set_dtypes(dtypes)
startup_var.persistable = True reader_var.persistable = True
reader_var.stop_gradient = True
if self._keep_order:
main_prog_var = reader_var
reader = main_prog_var
reader.reset = self._queue.reset
else:
main_prog_var = _copy_reader_var_(
default_main_program().current_block(), reader_var)
main_prog_var = _copy_reader_var_( main_prog_var.stop_gradient = True
default_main_program().current_block(), startup_var) main_prog_var.persistable = True
main_prog_var.stop_gradient = True reader = monkey_patch_reader_methods(main_prog_var)
main_prog_var.persistable = True
reader = monkey_patch_reader_methods(main_prog_var)
if self._use_double_buffer: if self._use_double_buffer:
double_buffer_reader = double_buffer( double_buffer_reader = double_buffer(
reader, name=double_buffer_name) reader, name=double_buffer_name)
...@@ -830,7 +912,8 @@ class GeneratorLoader(DataLoaderBase): ...@@ -830,7 +912,8 @@ class GeneratorLoader(DataLoaderBase):
default_main_program().current_block().append_op( default_main_program().current_block().append_op(
type='read', type='read',
inputs={'Reader': [self._reader]}, inputs={'Reader': [self._reader]},
outputs={'Out': self._feed_list}) outputs={'Out': self._feed_list},
attrs={'drop_last': self._drop_last})
@property @property
def queue(self): def queue(self):
...@@ -879,14 +962,20 @@ class GeneratorLoader(DataLoaderBase): ...@@ -879,14 +962,20 @@ class GeneratorLoader(DataLoaderBase):
" to locate the data causes this issue.\n\t* Please consider using " " to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")) "'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
return arr
def _start(self): def _start(self):
def __thread_main__(): def __thread_main__():
try: try:
while not self._queue.wait_for_inited(1):
if self._exited:
return
for tensors in self._tensor_reader(): for tensors in self._tensor_reader():
array = core.LoDTensorArray() array = core.LoDTensorArray()
for item in tensors: for item in tensors:
if not isinstance(item, core.LoDTensor): if not isinstance(item, core.LoDTensor):
self._check_input_array(item) item = self._check_input_array(item)
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace()) tmp.set(item, core.CPUPlace())
item = tmp item = tmp
...@@ -910,10 +999,12 @@ class GeneratorLoader(DataLoaderBase): ...@@ -910,10 +999,12 @@ class GeneratorLoader(DataLoaderBase):
def _reset(self): def _reset(self):
self._queue.close() self._queue.close()
self._exited = True
thread = self._thread thread = self._thread
if thread is not None: if thread is not None:
thread.join() thread.join()
self._exited = False
self._reader.reset() self._reader.reset()
def set_sample_generator(self, def set_sample_generator(self,
......
...@@ -359,7 +359,10 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu ...@@ -359,7 +359,10 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_parallel_executor_feed_persistable_var test_parallel_executor_feed_persistable_var
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_fetch_unmerged test_optimizer_in_control_flow test_dataloader_keep_order
test_dataloader_unkeep_order
test_parallel_executor_inference_feed_partial_data
test_fetch_unmerged
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
if(NOT WIN32 AND NOT APPLE) if(NOT WIN32 AND NOT APPLE)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import os
import six
def create_reader(shape, batch_number):
def __impl__():
idx = 0
for _ in six.moves.range(batch_number):
yield np.ones(shape).astype('float32') * idx,
idx += 1
return __impl__
class DataLoaderKeepOrderTestBase(unittest.TestCase):
def initParameters(self):
self.iterable = False
self.break_num = 10000
def setUp(self):
self.epoch_num = 3
self.batch_num = 40
self.shape = [3, 4, 5]
self.initParameters()
def build_network(self, places):
input_data = fluid.data(shape=self.shape, dtype='float32', name="input")
loader = fluid.io.DataLoader.from_generator(
capacity=16, feed_list=[input_data], iterable=self.iterable)
fc = fluid.layers.fc(input_data, size=10)
loss = fluid.layers.reduce_mean(fc)
loader.set_batch_generator(
create_reader(self.shape, self.batch_num),
places=places if loader.iterable else None)
return input_data, loss, loader
def assertInputData(self, batch_id, input_data, dev_cnt):
if isinstance(input_data, list):
self.assertEqual(len(input_data), dev_cnt)
start_val = dev_cnt * batch_id
for each_input_dict in input_data:
input_tensor = np.array(each_input_dict["input"])
self.assertEqual(self.shape, list(input_tensor.shape))
self.assertTrue((input_tensor == start_val).all())
start_val += 1
else:
self.assertEqual(
list(input_data.shape),
[self.shape[0] * dev_cnt] + self.shape[1:])
start_val = dev_cnt * batch_id
for idx in six.moves.range(dev_cnt):
data_part = input_data[idx * self.shape[0]:(idx + 1) *
self.shape[0], :]
self.assertTrue((data_part == start_val).all())
start_val += 1
def get_places(self):
place_list = [fluid.cpu_places(1), fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
place_list.extend([fluid.cuda_places(0), fluid.cuda_places([0, 1])])
return place_list
def test_main(self):
for p in self.get_places():
use_compiled_program_list = [True] if len(p) > 1 else [False, True]
for use_compiled_program in use_compiled_program_list:
self.run_main_with_place(p, use_compiled_program)
def run_main_with_place(self, places, use_compiled_program=True):
with fluid.scope_guard(fluid.Scope()):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data, loss, loader = self.build_network(places)
fetch_list = [input_data]
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
dev_cnt = len(places)
if dev_cnt > 1:
self.assertTrue(use_compiled_program)
main_program = fluid.default_main_program()
if use_compiled_program:
main_program = fluid.CompiledProgram(
main_program).with_data_parallel(
loss_name=loss.name, places=places)
max_batch_num = min(self.break_num,
int(self.batch_num / dev_cnt))
if loader.iterable:
early_break = False
for epoch_id in six.moves.range(self.epoch_num):
early_break = False
batch_id = 0
for data in loader():
if batch_id >= self.break_num:
early_break = True
break
self.assertInputData(batch_id, data, dev_cnt)
fetch_val, = exe.run(program=main_program,
feed=data,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val, dev_cnt)
batch_id += 1
self.assertEqual(batch_id, max_batch_num)
if early_break:
loader._reset()
else:
for epoch_id in six.moves.range(self.epoch_num):
batch_id = 0
loader.start()
try:
while True:
if batch_id >= self.break_num:
loader.reset()
break
fetch_val, = exe.run(program=main_program,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val,
dev_cnt)
batch_id += 1
except fluid.core.EOFException:
loader.reset()
self.assertEqual(batch_id, max_batch_num)
class IterableDataLoaderKeepOrderTest2(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 10000
class IterableDataLoaderKeepOrderTest3(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 2
class IterableDataLoaderKeepOrderTest4(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 2
class IterableDataLoaderKeepOrderTest5(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 0
class IterableDataLoaderKeepOrderTest6(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 0
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import os
import six
from paddle.fluid.reader import keep_data_loader_order
keep_data_loader_order(False)
def create_reader(shape, batch_number):
def __impl__():
idx = 0
for _ in six.moves.range(batch_number):
yield np.ones(shape).astype('float32') * idx,
idx += 1
return __impl__
class DataLoaderKeepOrderTestBase(unittest.TestCase):
def initParameters(self):
self.iterable = False
self.break_num = 10000
def setUp(self):
self.epoch_num = 3
self.batch_num = 40
self.shape = [3, 4, 5]
self.initParameters()
def clear_visited(self):
self.visited = set()
def build_network(self, places):
input_data = fluid.data(shape=self.shape, dtype='float32', name="input")
loader = fluid.io.DataLoader.from_generator(
capacity=16, feed_list=[input_data], iterable=self.iterable)
fc = fluid.layers.fc(input_data, size=10)
loss = fluid.layers.reduce_mean(fc)
loader.set_batch_generator(
create_reader(self.shape, self.batch_num),
places=places if loader.iterable else None)
return input_data, loss, loader
def assertInputData(self, batch_id, input_data, dev_cnt,
check_visited=True):
if isinstance(input_data, list):
self.assertEqual(len(input_data), dev_cnt)
start_val = dev_cnt * batch_id
for each_input_dict in input_data:
input_tensor = np.array(each_input_dict["input"])
self.assertEqual(self.shape, list(input_tensor.shape))
num = input_tensor.flatten()[0]
equal = (input_tensor == num).all()
self.assertTrue(equal)
if check_visited:
self.assertTrue(num not in self.visited)
self.visited.add(num)
start_val += 1
else:
self.assertEqual(
list(input_data.shape),
[self.shape[0] * dev_cnt] + self.shape[1:])
start_val = dev_cnt * batch_id
for idx in six.moves.range(dev_cnt):
data_part = input_data[idx * self.shape[0]:(idx + 1) *
self.shape[0], :]
num = data_part.flatten()[0]
self.assertTrue((data_part == num).all())
if check_visited:
self.assertTrue(num not in self.visited)
self.visited.add(num)
start_val += 1
def get_places(self):
place_list = [fluid.cpu_places(1), fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
place_list.extend([fluid.cuda_places(0), fluid.cuda_places([0, 1])])
return place_list
def test_main(self):
for p in self.get_places():
use_compiled_program_list = [True] if len(p) > 1 else [False, True]
for use_compiled_program in use_compiled_program_list:
self.run_main_with_place(p, use_compiled_program)
def run_main_with_place(self, places, use_compiled_program=True):
with fluid.scope_guard(fluid.Scope()):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data, loss, loader = self.build_network(places)
fetch_list = [input_data]
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
dev_cnt = len(places)
if dev_cnt > 1:
self.assertTrue(use_compiled_program)
main_program = fluid.default_main_program()
if use_compiled_program:
main_program = fluid.CompiledProgram(
main_program).with_data_parallel(
loss_name=loss.name, places=places)
max_batch_num = min(self.break_num,
int(self.batch_num / dev_cnt))
if loader.iterable:
early_break = False
for epoch_id in six.moves.range(self.epoch_num):
early_break = False
self.clear_visited()
batch_id = 0
for data in loader():
if batch_id >= self.break_num:
early_break = True
break
self.assertInputData(
batch_id, data, dev_cnt, check_visited=False)
fetch_val, = exe.run(program=main_program,
feed=data,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val, dev_cnt)
batch_id += 1
if dev_cnt == 1:
self.assertEqual(batch_id, max_batch_num)
else:
self.assertLessEqual(batch_id, max_batch_num)
if early_break:
loader._reset()
else:
for epoch_id in six.moves.range(self.epoch_num):
batch_id = 0
self.clear_visited()
loader.start()
try:
while True:
if batch_id >= self.break_num:
loader.reset()
break
fetch_val, = exe.run(program=main_program,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val,
dev_cnt)
batch_id += 1
except fluid.core.EOFException:
loader.reset()
if dev_cnt == 1:
self.assertEqual(batch_id, max_batch_num)
else:
self.assertLessEqual(batch_id, max_batch_num)
class IterableDataLoaderKeepOrderTest2(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 10000
class IterableDataLoaderKeepOrderTest3(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 2
class IterableDataLoaderKeepOrderTest4(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 2
class IterableDataLoaderKeepOrderTest5(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 0
class IterableDataLoaderKeepOrderTest6(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 0
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
import unittest
import six
class TestInferencePartialFeed(unittest.TestCase):
def setUp(self):
self.iterations = 10
self.size = 10
def run_network(self, places, use_split, has_persistable):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
if has_persistable:
lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
else:
lr = fluid.data(name='lr', shape=[None], dtype='float32')
relu_x = fluid.layers.relu(x)
relu_y = fluid.layers.relu(y)
relu_lr = fluid.layers.relu(lr)
exe = fluid.Executor(places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog).with_data_parallel(
places=places)
gen_random = lambda shape: np.random.uniform(low=-1.0, high=1.0, size=shape).astype('float32')
assert_result = lambda feed, result: self.assertTrue(np.array_equal(np.maximum(0, feed), result))
def assert_merged_unmerged(merged, unmerged):
unmerged = np.concatenate(unmerged, axis=0)
self.assertTrue(np.array_equal(merged, unmerged))
def feed_split_test():
for place_num in six.moves.range(1, len(places) * 3):
x_np = gen_random([place_num, self.size])
y_np = gen_random([place_num, self.size])
if not lr.persistable or place_num <= len(places):
lr_np = gen_random([place_num])
else:
lr_np = gen_random([1])
feed = {x.name: x_np, y.name: y_np, lr.name: lr_np}
fetch_list = [relu_x, relu_y, relu_lr]
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog, feed=feed, fetch_list=fetch_list, return_merged=True)
relu_x_np_unmerged, relu_y_np_unmerged, relu_lr_np_unmerged = exe.run(
prog, feed=feed, fetch_list=fetch_list, return_merged=False)
assert_merged_unmerged(relu_x_np, relu_x_np_unmerged)
assert_merged_unmerged(relu_y_np, relu_y_np_unmerged)
assert_merged_unmerged(relu_lr_np, relu_lr_np_unmerged)
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
if not lr.persistable or place_num <= len(places):
assert_result(lr_np, relu_lr_np)
else:
expected_relu_lr_np = max(lr_np[0], 0)
self.assertTrue(np.all(expected_relu_lr_np == relu_lr_np))
def feed_list_test():
for place_num in six.moves.range(1, len(places) + 1):
x_np_list = []
y_np_list = []
lr_np_list = []
feed_list = []
for _ in six.moves.range(place_num):
x_np = gen_random([1, self.size])
y_np = gen_random([1, self.size])
lr_np = gen_random([1])
x_np_list.append(x_np)
y_np_list.append(y_np)
lr_np_list.append(lr_np)
feed_list.append({
x.name: x_np,
y.name: y_np,
lr.name: lr_np
})
fetch_list = [relu_x, relu_y, relu_lr]
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog,
feed=feed_list,
fetch_list=fetch_list,
return_merged=True)
relu_x_np_unmerged, relu_y_np_unmerged, relu_lr_np_unmerged = exe.run(
prog,
feed=feed_list,
fetch_list=fetch_list,
return_merged=False)
assert_merged_unmerged(relu_x_np, relu_x_np_unmerged)
assert_merged_unmerged(relu_y_np, relu_y_np_unmerged)
assert_merged_unmerged(relu_lr_np, relu_lr_np_unmerged)
x_np = np.concatenate(x_np_list)
y_np = np.concatenate(y_np_list)
lr_np = np.concatenate(lr_np_list)
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
assert_result(lr_np, relu_lr_np)
for _ in six.moves.range(self.iterations):
if use_split:
feed_split_test()
else:
feed_list_test()
def test_main(self):
places = [fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
places.append(fluid.cuda_places())
for p in places:
for has_persistable in [False, True]:
for use_split in [False, True]:
self.run_network(
p, use_split=use_split, has_persistable=has_persistable)
class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
def setUp(self):
self.epoch_num = 3
self.batch_num = 101 # a prime number
self.batch_size = 32
def create_reader(self):
def __impl__():
for _ in six.moves.range(self.batch_num):
yield np.random.random([self.batch_size, 1]).astype('float32'),
return __impl__
def run_network(self, iterable, use_cuda, drop_last):
x = fluid.data(shape=[None, 1], name='x', dtype='float32')
places = fluid.cuda_places() if use_cuda else fluid.cpu_places(4)
loader = fluid.io.DataLoader.from_generator(
feed_list=[x], capacity=16, iterable=iterable, drop_last=drop_last)
y = fluid.layers.fc(x, size=10)
loss = fluid.layers.reduce_mean(y)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
places=places, loss_name=loss.name)
loader.set_batch_generator(
self.create_reader(), places=places if iterable else None)
for _ in six.moves.range(self.epoch_num):
actual_batch_num = 0
if loader.iterable:
for feed_data in loader():
x_data, = exe.run(prog, feed=feed_data, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] / self.batch_size)
else:
loader.start()
try:
while True:
x_data, = exe.run(prog, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] /
self.batch_size)
except fluid.core.EOFException:
loader.reset()
if not drop_last or len(places) == 1:
self.assertEqual(self.batch_num, actual_batch_num)
else:
self.assertGreater(self.batch_num, actual_batch_num)
def test_main(self):
use_cuda_list = [False, True] if fluid.is_compiled_with_cuda(
) else [False]
iterable_list = [False, True]
drop_last_list = [False, True]
for iterable in iterable_list:
for use_cuda in use_cuda_list:
for drop_last in drop_last_list:
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.run_network(iterable, use_cuda, drop_last)
if __name__ == '__main__':
unittest.main()