Unverified commit acfc9b8a authored by Zeng Jinle, committed by GitHub

Reader sequential and inference partial feed (#22699)

* sequential reader stage 1, test=develop

* fix ut, test=develop

* fix iterable=False reset bug, add some logs and polish code, test=develop

* inference feed partial data, test=develop

* Turn on keep_order=True for test, test=develop

* enhance ut to test more cases, test=develop

* test commit for reverting

* Revert "test commit for reverting", test=develop

This reverts commit 80aef42e.

* add ut of merged and unmerged results, test=develop

* add more uts for coverages and add en doc of api, test=develop

* follow comments, test=develop

* change note style, test=develop
Parent 95b356a0
......@@ -9,7 +9,7 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper INTERFACE SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -65,13 +65,15 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
multi_devices_helper
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
reference_count_pass
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass)
buffer_shared_cross_op_memory_reuse_pass
set_reader_device_info_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -66,6 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
AppendMultiDevPass();
AppendSetReaderDeviceIndexPass();
AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
......@@ -227,6 +228,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&strategy_);
}
void AppendSetReaderDeviceIndexPass() {
AppendPass("set_reader_device_index_pass");
}
void AppendPrintGraphPass(const std::string &pass_name,
const std::string &debug_file_suffix) {
if (!strategy_.debug_graphviz_path_.empty()) {
......@@ -397,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "set_reader_device_index_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
}
VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......@@ -433,6 +441,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(set_reader_device_index_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
......@@ -34,6 +34,8 @@ class ComputationOpHandle : public OpHandleBase {
OperatorBase *GetOp() { return op_.get(); }
const OperatorBase *GetOp() const { return op_.get(); }
std::string Name() const override;
const Scope *GetScope() const { return scope_; }
......
......@@ -31,10 +31,12 @@ namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, Scope *scope, const platform::Place &place,
ir::Node *node, Scope *scope, size_t scope_idx,
const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
: OpHandleBase(node),
scope_(scope),
scope_idx_(scope_idx),
place_(place),
var_infos_(vars.begin(), vars.end()),
gc_(gc) {
......
......@@ -34,7 +34,7 @@ namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, Scope *scope,
EagerDeletionOpHandle(ir::Node *node, Scope *scope, size_t scope_idx,
const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars,
GarbageCollector *gc);
......@@ -50,6 +50,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
*/
Priority GetPriority() const override { return kHighest; }
size_t GetScopeIdx() const { return scope_idx_; }
protected:
void RunImpl() override;
......@@ -63,6 +65,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
void CallOnce();
Scope *scope_;
size_t scope_idx_;
platform::Place place_;
std::vector<ir::MemOptVarInfo *> var_infos_; // not own
GarbageCollector *gc_; // not own
......
......@@ -12,9 +12,227 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {} // namespace details
namespace details {
static constexpr size_t kUndefinedDevIdx = -1UL;
// NOTE(paddle-dev): the following ops are related to multi-device
// communication. If the graph contains any of them, it cannot be
// separated into multiple single-device graphs.
static std::unordered_set<std::string> kMultiDeviceOps{
"sync_batch_norm",
"sync_batch_norm_grad",
"allreduce",
"c_allreduce_sum",
"c_allreduce_prod",
"c_allreduce_min",
"c_allreduce_max",
"c_allgather",
"c_reducescatter",
"c_broadcast",
"c_comm_init",
"c_comm_init_all",
"c_gen_nccl_id",
"c_sync_comm_stream",
"send",
"recv",
"send_barrier",
"fetch_barrier",
};
static size_t GetScopeIdxFromOp(const details::OpHandleBase &op) {
if (auto *compute_op =
dynamic_cast<const details::ComputationOpHandle *>(&op)) {
return kMultiDeviceOps.count(compute_op->GetOp()->Type()) == 0
? compute_op->GetScopeIdx()
: kUndefinedDevIdx;
} else if (auto *gc_op =
dynamic_cast<const details::EagerDeletionOpHandle *>(&op)) {
return gc_op->GetScopeIdx();
} else if (auto *share_op =
dynamic_cast<const details::ShareTensorBufferOpHandle *>(
&op)) {
return share_op->GetScopeIdx();
} else {
return kUndefinedDevIdx;
}
}
static bool ContainMultiDeviceOp(const ProgramDesc &program,
size_t begin_block_idx) {
for (size_t block_idx = begin_block_idx; block_idx < program.Size();
++block_idx) {
for (auto *op_desc : program.Block(block_idx).AllOps()) {
if (kMultiDeviceOps.count(op_desc->Type()) > 0) {
return true;
}
}
}
return false;
}
static size_t GetUniqueDeviceIdOfOp(const details::OpHandleBase &op) {
size_t dev_idx = GetScopeIdxFromOp(op);
if (dev_idx == kUndefinedDevIdx) {
return kUndefinedDevIdx;
}
const auto &ins = op.Inputs();
const auto &outs = op.Outputs();
auto in_outs = ins;
in_outs.insert(in_outs.end(), outs.begin(), outs.end());
for (auto *var : in_outs) {
auto *var_handle = dynamic_cast<details::VarHandle *>(var);
if (var_handle == nullptr) {
continue;
}
if (dev_idx != var_handle->scope_idx()) {
return kUndefinedDevIdx;
}
}
return dev_idx;
}
/**
* This function tries to separate the original graph into multiple graphs,
* each of which runs on a single device only. This is usually used to split
* a data-parallel inference graph into one graph per device.
*
* The graph can be separated into multiple single-device graphs if and only if:
*
* - the graph does not contain any op related to multi-device communication,
* such as allreduce, send, recv, sync_batch_norm, etc.
*
* - ops on different devices do not depend on each other. That is to say, the
* graph consists of several disconnected sub-graphs, one per device.
*/
std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
ir::Graph *graph) {
// If any sub-block contains multi-device ops, we cannot separate the graph
if (ContainMultiDeviceOp(graph->OriginProgram(), 1)) {
return {};
}
size_t place_num = 0;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
if (op_handles.empty()) {
return {};
}
std::unordered_map<details::OpHandleBase *, size_t> op_to_dev_idx;
for (auto &op : op_handles) {
auto dev_idx = GetUniqueDeviceIdOfOp(*op);
if (dev_idx == kUndefinedDevIdx) {
VLOG(10) << "Op " << op->Name() << " is not determined";
return {};
}
place_num = std::max(place_num, dev_idx + 1);
op_to_dev_idx[op] = dev_idx;
}
for (auto &op : op_handles) {
auto dev_idx = op_to_dev_idx.at(op);
for (auto &in_var : op->Inputs()) {
if (in_var->GeneratedOp()) {
auto iter = op_to_dev_idx.find(in_var->GeneratedOp());
if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
return {};
}
}
}
for (auto &out_var : op->Outputs()) {
for (auto &pending_op : out_var->PendingOps()) {
auto iter = op_to_dev_idx.find(pending_op);
if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
return {};
}
}
}
}
PADDLE_ENFORCE_GE(
place_num, 1,
platform::errors::NotFound(
"No place found, this may be a bug.\nIt would be helpful if you "
"could inform us of how this conversion went by opening a github "
"issue at https://github.com/PaddlePaddle/Paddle/issues/new. And "
"we will resolve it with high priority."));
std::vector<std::unique_ptr<ir::Graph>> graphs(place_num);
for (auto &g : graphs) {
g.reset(new ir::Graph(ProgramDesc()));
g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars());
}
for (auto &op : op_handles) {
auto dev_idx = op_to_dev_idx.at(op);
auto *ret_graph = graphs[dev_idx].get();
auto &ret_vars = ret_graph->Get<GraphVars>(kGraphVars)[0];
auto &ret_dummy_vars = ret_graph->Get<GraphDepVars>(kGraphDepVars);
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_idx];
ret_graph->AddNode(graph->RemoveNode(op->Node()).release());
auto handler = [&](const std::vector<VarHandleBase *> &vars) {
for (auto *var : vars) {
if (graph->Nodes().count(var->Node()) > 0) {
ret_graph->AddNode(graph->RemoveNode(var->Node()).release());
auto *dummy_var = dynamic_cast<DummyVarHandle *>(var);
if (dummy_var == nullptr) {
ret_vars.emplace(var->Name(), origin_vars.at(var->Name()));
} else {
ret_dummy_vars.emplace(dummy_var);
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
}
graph->Erase(kGraphVars);
graph->Erase(kGraphDepVars);
return graphs;
}
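The splittability check above can be summarized with a small standalone sketch; the Op struct and CanSeparate helper below are hypothetical stand-ins for the real OpHandleBase/VarHandle machinery. The graph is splittable only if every op is pinned to exactly one device and every producer-consumer edge stays on that device.

// Toy sketch of the splittability check; Op and CanSeparate are hypothetical
// stand-ins, not the real Paddle handles.
#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <vector>

constexpr size_t kUndefinedDevIdx = static_cast<size_t>(-1);

struct Op {
  size_t dev_idx = kUndefinedDevIdx;  // device the op is assigned to
  std::vector<const Op*> inputs;      // producers of its input variables
};

// Returns true iff the ops form disconnected per-device sub-graphs.
bool CanSeparate(const std::vector<Op>& ops) {
  std::unordered_map<const Op*, size_t> op_to_dev;
  for (const auto& op : ops) {
    if (op.dev_idx == kUndefinedDevIdx) return false;  // undetermined op
    op_to_dev.emplace(&op, op.dev_idx);
  }
  for (const auto& op : ops) {
    for (const Op* in : op.inputs) {
      auto it = op_to_dev.find(in);
      if (it == op_to_dev.end() || it->second != op.dev_idx) {
        return false;  // cross-device dependency -> cannot split
      }
    }
  }
  return true;
}

int main() {
  std::vector<Op> ops(4);
  ops[0].dev_idx = 0;
  ops[1].dev_idx = 0;
  ops[1].inputs = {&ops[0]};  // same-device edge: fine
  ops[2].dev_idx = 1;
  ops[3].dev_idx = 1;
  ops[3].inputs = {&ops[2]};
  std::cout << CanSeparate(ops) << "\n";  // 1: two disconnected sub-graphs
  ops[3].inputs.push_back(&ops[0]);       // device-1 op now depends on device 0
  std::cout << CanSeparate(ops) << "\n";  // 0: cannot split
  return 0;
}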
static bool HasDropLastReadOpImpl(const ir::Graph &graph, bool drop_last) {
auto ops = ir::FilterByNodeWrapper<OpHandleBase>(graph);
for (auto *op : ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op && compute_op->GetOp()->Type() == "read" &&
compute_op->GetOp()->Attr<bool>("drop_last") == drop_last) {
VLOG(10) << "The graph has drop_last=" << drop_last << " read op";
return true;
}
}
VLOG(10) << "The graph does not have drop_last=" << drop_last << " read op";
return false;
}
bool HasDropLastReadOp(const ir::Graph &graph) {
return HasDropLastReadOpImpl(graph, true);
}
bool HasKeepLastReadOp(const ir::Graph &graph) {
return HasDropLastReadOpImpl(graph, false);
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -47,6 +47,7 @@ constexpr char kGraphVars[] = "vars";
constexpr char kNRanks[] = "nranks";
constexpr char kPlaces[] = "places";
constexpr char kGlobalScope[] = "global_scope";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
......@@ -100,6 +101,13 @@ inline std::vector<std::string> GetOpRoleVarsOrEmpty(const OpDesc &op) {
return boost::get<std::vector<std::string>>(iter->second);
}
std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
ir::Graph *graph);
bool HasDropLastReadOp(const ir::Graph &graph);
bool HasKeepLastReadOp(const ir::Graph &graph);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -21,11 +22,11 @@ namespace paddle {
namespace framework {
namespace details {
std::vector<std::unique_ptr<ir::Graph>>
ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
static std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
ir::Graph *graph, size_t place_num) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places_.size());
for (size_t i = 0; i < places_.size(); ++i) {
graphs.reserve(place_num);
for (size_t i = 0; i < place_num; ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
......@@ -64,7 +65,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
}
}
for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
for (size_t dev_id = 0; dev_id < place_num; ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
......@@ -85,15 +86,34 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
// TODO(Yancey1989): Copying graphs is not safe since it deletes the
// attrs.
: ParallelSSAGraphExecutor(strategy, local_scopes, local_exec_scopes,
places,
SeparateMultiDevicesGraph(graph,
places.size())) {}
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
// TODO(Yancey1989): Copying graphs is not safe since it deletes the
// attrs.
graphs_(SeparateMultiDevicesGraph(graph)) {
places_(places),
graphs_(std::move(graphs)),
feed_status_(places.size(), FeedStatus::kNone) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
PADDLE_ENFORCE_EQ(places_.size(), graphs_.size(),
platform::errors::InvalidArgument(
"Graph number does not match place number"));
PADDLE_ENFORCE_GT(
places_.size(), 0,
platform::errors::InvalidArgument("place number must be larger than 0"));
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
......@@ -123,28 +143,41 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
return result;
}
enum ExceptionStatus { kSuccess = 0, kEOF, kOther };
FetchResultType ParallelSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
size_t feed_num = std::count(feed_status_.begin(), feed_status_.end(),
FeedStatus::kHasFeed);
bool has_feed = (feed_num > 0);
VLOG(10) << "Feed num " << feed_num;
size_t place_num = places_.size();
std::vector<std::future<FetchResultType>> run_futures;
std::vector<ExceptionStatus> exception_status(place_num,
ExceptionStatus::kSuccess);
std::vector<FetchResultType> fetch_data;
FetchResultType ret;
fetch_data.reserve(places_.size());
if (return_merged) {
ret = FeedFetchList();
} else {
ret = FetchUnmergedList();
}
fetch_data.reserve(place_num);
exception_holder_.Clear();
for (size_t i = 0; i < places_.size(); ++i) {
auto call = [this, i, return_merged, &fetch_tensors]() -> FetchResultType {
for (size_t i = 0; i < place_num; ++i) {
auto call = [&, i]() -> FetchResultType {
try {
return executors_[i]->Run(fetch_tensors, return_merged);
if (!support_partial_feed_ || !has_feed ||
feed_status_[i] == FeedStatus::kHasFeed) {
return executors_[i]->Run(fetch_tensors, return_merged);
}
} catch (platform::EOFException &) {
exception_status[i] = ExceptionStatus::kEOF;
exception_holder_.Catch(std::current_exception());
} catch (...) {
exception_status[i] = ExceptionStatus::kOther;
exception_holder_.Catch(std::current_exception());
}
if (return_merged) {
return FeedFetchList();
} else {
......@@ -161,46 +194,96 @@ FetchResultType ParallelSSAGraphExecutor::Run(
if (pool_) {
for (auto &f : run_futures) {
if (exception_holder_.IsCaught()) {
f.wait();
} else {
fetch_data.emplace_back(f.get());
fetch_data.emplace_back(f.get());
}
}
bool has_exception = exception_holder_.IsCaught();
if (!support_partial_feed_ && has_exception) {
VLOG(10) << "Exception rethrow because partial feed is not supported";
exception_holder_.ReThrow();
}
std::vector<bool> is_valid(place_num, true);
if (support_partial_feed_) {
if (has_feed) {
for (size_t i = 0; i < place_num; ++i) {
if (feed_status_[i] == FeedStatus::kNone) {
is_valid[i] = false;
} else if (exception_status[i] != ExceptionStatus::kSuccess) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Exception rethrow because non-EOF exception raises when "
"feed is given";
exception_holder_.ReThrow();
}
}
} else {
for (size_t i = 0; i < place_num; ++i) {
if (exception_status[i] == ExceptionStatus::kOther) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Exception rethrow because non-EOF exception raises when "
"feed is not given";
exception_holder_.ReThrow();
} else if (exception_status[i] != ExceptionStatus::kSuccess) {
is_valid[i] = false;
}
}
}
}
if (exception_holder_.IsCaught()) {
if (std::count(is_valid.begin(), is_valid.end(), true) == 0) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Raise exception because there is no success worker";
exception_holder_.ReThrow();
}
if (return_merged) {
auto &ret_val = boost::get<FeedFetchList>(ret);
FeedFetchList ret;
ret.reserve(fetch_tensors.size());
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.reserve(local_scopes_.size());
for (size_t scope_idx = 0; scope_idx < local_scopes_.size();
++scope_idx) {
auto &val = boost::get<FeedFetchList>(fetch_data.at(scope_idx));
lodtensor_ptrs.push_back(&val.at(fetch_idx));
lodtensor_ptrs.reserve(place_num);
for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
if (!is_valid[scope_idx]) {
continue;
}
const auto &fetch_list =
boost::get<FeedFetchList>(fetch_data[scope_idx]);
lodtensor_ptrs.push_back(&fetch_list[fetch_idx]);
}
ret_val.emplace_back();
ret_val.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return ret;
} else {
auto &ret_val = boost::get<FetchUnmergedList>(ret);
FetchUnmergedList ret;
ret.reserve(fetch_tensors.size());
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
ret_val.emplace_back();
ret.emplace_back();
for (size_t scope_idx = 0; scope_idx < local_scopes_.size();
++scope_idx) {
auto &val = boost::get<FetchUnmergedList>(fetch_data.at(scope_idx));
if (!is_valid[scope_idx]) {
continue;
}
const auto &fetch_list =
boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
PADDLE_ENFORCE_EQ(
val.at(fetch_idx).size(), 1,
fetch_list[fetch_idx].size(), 1,
platform::errors::Fatal(
"Each place must have only one fetched LoDTensor!"));
ret_val.back().emplace_back(val.at(fetch_idx)[0]);
ret.back().emplace_back(fetch_list[fetch_idx][0]);
}
}
return ret;
}
return ret;
}
} // namespace details
......
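The merged-fetch path in Run() above can be pictured with a toy sketch, using a plain string-based FetchList in place of the real FeedFetchList: devices that received no feed or hit EOF are marked invalid and simply skipped when each fetched variable is concatenated across devices.

// Toy sketch of the merged return path; FetchList here is a hypothetical
// string-based stand-in, not the real FeedFetchList.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

using FetchList = std::vector<std::string>;  // one entry per fetch target

int main() {
  // Per-device results for two fetch targets, e.g. "loss" and "acc".
  std::vector<FetchList> fetch_data = {{"loss@dev0", "acc@dev0"},
                                       {},  // device 1: no feed, no result
                                       {"loss@dev2", "acc@dev2"}};
  std::vector<bool> is_valid = {true, false, true};
  const size_t fetch_num = 2;
  for (size_t fetch_idx = 0; fetch_idx < fetch_num; ++fetch_idx) {
    std::cout << "merged fetch " << fetch_idx << ":";
    for (size_t dev = 0; dev < fetch_data.size(); ++dev) {
      if (!is_valid[dev]) continue;  // skip devices without valid results
      std::cout << " " << fetch_data[dev][fetch_idx];
    }
    std::cout << "\n";
  }
  return 0;
}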
......@@ -27,12 +27,25 @@ namespace framework {
namespace details {
class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public:
enum FeedStatus {
kNone = 0, // No feed
kHasFeed = 1 // Has feed
};
public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs);
~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
......@@ -42,10 +55,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged) override;
private:
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
ir::Graph *graph);
void SetHasFeed(size_t dev_idx, bool has_feed) {
feed_status_[dev_idx] = has_feed ? FeedStatus::kHasFeed : FeedStatus::kNone;
}
void EnablePartialFeedSupport() { support_partial_feed_ = true; }
bool SupportPartialFeed() const { return support_partial_feed_; }
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
......@@ -55,6 +73,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::unique_ptr<details::FastThreadedSSAGraphExecutor>>
executors_;
ExceptionHolder exception_holder_;
bool support_partial_feed_{false};
std::vector<FeedStatus> feed_status_;
};
} // namespace details
......
......@@ -228,7 +228,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
}
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(),
eager_deletion_node, op->GetScope(), op->GetScopeIdx(), op->GetPlace(),
std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get());
auto it = std::find_if(
......
......@@ -98,7 +98,7 @@ class ReferenceCountPassTestHelper {
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_);
ref_cnt_pass->Apply(&graph_);
ref_cnt_pass->Apply(&const_cast<ir::Graph &>(executor_->Graph()));
}
bool IsLastLivedOps(const std::string &name,
......
......@@ -11,6 +11,7 @@ endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace paddle {
namespace framework {
namespace ir {
static int GetDeviceCountFromPassAttr(const Pass &pass) {
return static_cast<int>(
pass.Get<const std::vector<platform::Place>>(details::kPlaces).size());
}
static std::unordered_set<std::string> ReaderOpSet() {
return {"create_py_reader"};
}
class InitReaderDeviceCountPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
using QueueHolder =
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
auto reader_ops = ReaderOpSet();
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
const auto &scope = Get<const Scope>(details::kGlobalScope);
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto queue_name = node->Op()->Input("blocking_queue")[0];
auto var = scope.FindVar(queue_name);
if (var && var->IsType<QueueHolder>()) {
VLOG(10) << "Set device count of " << queue_name << " to be "
<< dev_cnt;
var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
}
}
}
}
};
class SetReaderDeviceIndexPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
auto reader_ops = ReaderOpSet();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
node->Wrapper<details::OpHandleBase>());
auto *op_desc = node->Op();
auto &op_base_attrs =
const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
int dev_idx = static_cast<int>(op_handle.GetScopeIdx());
op_desc->SetAttr("device_index", dev_idx);
op_desc->SetAttr("device_count", dev_cnt);
op_base_attrs["device_index"] = dev_idx;
op_base_attrs["device_count"] = dev_cnt;
++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
}
}
VLOG(10) << "Found op number " << found_op_num;
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(init_reader_device_count_pass,
paddle::framework::ir::InitReaderDeviceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalScope)
.RequirePassAttr(paddle::framework::details::kPlaces);
REGISTER_PASS(set_reader_device_index_pass,
paddle::framework::ir::SetReaderDeviceIndexPass)
.RequirePassAttr(paddle::framework::details::kPlaces);
......@@ -307,18 +307,18 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const {
PADDLE_ENFORCE_GT(places.size(), 0,
platform::errors::InvalidArgument(
"place number cannot be empty when splitting"));
check_memory_size();
int batch_size =
lod().empty() ? dims()[0] : static_cast<int>(lod()[0].size()) - 1;
size_t result_size = std::min(static_cast<size_t>(batch_size), places.size());
size_t remainder = batch_size % places.size();
size_t batch_size =
lod().empty() ? static_cast<size_t>(dims()[0]) : lod()[0].size() - 1;
std::vector<LoDTensor> results;
results.reserve(result_size);
// if result_size(batch_size) is 0, just return #places.size() copies of empty
// if batch_size is 0, just return #places.size() copies of empty
// tensors.
if (result_size == 0) {
if (batch_size == 0) {
std::vector<LoDTensor> empty_results;
empty_results.reserve(places.size());
for (size_t i = 0; i < places.size(); ++i) {
LoDTensor dst;
dst.Resize(dims());
......@@ -326,18 +326,22 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
if (!lod().empty()) {
dst.set_lod(lod());
}
results.emplace_back(dst);
empty_results.emplace_back(std::move(dst));
}
return results;
return empty_results;
}
int step_width = static_cast<int>(batch_size / result_size);
auto step_width = (batch_size + places.size() - 1) / places.size();
auto result_size = (batch_size + step_width - 1) / step_width;
std::vector<LoDTensor> results;
results.reserve(result_size);
for (size_t i = 0; i < result_size; ++i) {
int begin = static_cast<int>(i * step_width);
int end = static_cast<int>((i + 1) * step_width);
if (i + 1 == places.size()) { // last
end += remainder;
}
auto begin = i * step_width;
auto end = std::min<size_t>((i + 1) * step_width, batch_size);
PADDLE_ENFORCE_LT(begin, end,
platform::errors::InvalidArgument(
"begin must be less than end, this may be a bug"));
LoDTensor dst;
if (lod().empty()) {
......@@ -362,7 +366,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
}
dst.set_lod(my_lod);
}
results.emplace_back(dst);
results.emplace_back(std::move(dst));
}
return results;
......
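The reworked SplitLoDTensor partitioning above uses ceiling division for step_width, so the number of produced pieces can be smaller than the place count, which is what partial feed relies on. Below is a small standalone sketch of that arithmetic, using a hypothetical SplitSizes helper rather than the real tensor code.

// Hypothetical standalone sketch of the new SplitLoDTensor partitioning:
// step_width = ceil(batch / places), result_size = ceil(batch / step_width).
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<size_t> SplitSizes(size_t batch_size, size_t place_num) {
  if (batch_size == 0) return {};
  size_t step_width = (batch_size + place_num - 1) / place_num;
  size_t result_size = (batch_size + step_width - 1) / step_width;
  std::vector<size_t> sizes;
  for (size_t i = 0; i < result_size; ++i) {
    size_t begin = i * step_width;
    size_t end = std::min((i + 1) * step_width, batch_size);
    sizes.push_back(end - begin);
  }
  return sizes;
}

int main() {
  for (size_t n : SplitSizes(10, 4)) std::cout << n << ' ';  // 3 3 3 1
  std::cout << '\n';
  for (size_t n : SplitSizes(5, 4)) std::cout << n << ' ';   // 2 2 1 -> only 3 pieces
  std::cout << '\n';
  return 0;
}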
......@@ -55,8 +55,9 @@ static bool gProfileStarted = false;
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
: places_(places) {
ParallelExecutorPrivate(const std::vector<platform::Place> &places,
Scope *global_scope)
: places_(places), global_scope_(global_scope) {
if (!FLAGS_pe_profile_fname.empty()) {
std::call_once(gProfileOnce, [] {
#ifdef WITH_GPERFTOOLS
......@@ -82,6 +83,19 @@ class ParallelExecutorPrivate {
}
}
void InitReaderDeviceCount(ir::Graph *graph) const {
auto pass =
ir::PassRegistry::Instance().Get("init_reader_device_count_pass");
pass->SetNotOwned<const Scope>(details::kGlobalScope, global_scope_);
pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
&places_);
pass->Apply(graph);
}
void SetHasFeed(size_t dev_idx, bool has_feed = true);
bool AllowPartialFeed() const;
ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
......@@ -257,8 +271,20 @@ class ParallelExecutorPrivate {
ir::MemOptVarInfoMapList mem_opt_var_infos_;
ir::GarbageCollectorMap gcs_;
details::ParallelSSAGraphExecutor *inference_executor_{nullptr};
};
void ParallelExecutorPrivate::SetHasFeed(size_t dev_idx, bool has_feed) {
if (inference_executor_) {
inference_executor_->SetHasFeed(dev_idx, has_feed);
}
}
bool ParallelExecutorPrivate::AllowPartialFeed() const {
return inference_executor_ && inference_executor_->SupportPartialFeed();
}
ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
if (FLAGS_use_ngraph) {
LOG_FIRST_N(WARNING, 1)
......@@ -379,6 +405,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
return graph;
}
class ResetHasFeedGuard {
public:
explicit ResetHasFeedGuard(ParallelExecutorPrivate *pe_member)
: pe_member_(pe_member) {}
~ResetHasFeedGuard() {
for (size_t i = 0; i < pe_member_->places_.size(); ++i) {
pe_member_->SetHasFeed(i, false);
}
}
private:
ParallelExecutorPrivate *pe_member_;
};
size_t ParallelExecutor::DeviceCount() const { return member_->places_.size(); }
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
......@@ -407,8 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
ir::Graph *graph)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
: member_(new ParallelExecutorPrivate(places, scope)) {
member_->InitReaderDeviceCount(graph);
member_->use_cuda_ = exec_strategy.use_cuda_;
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
......@@ -616,18 +657,38 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
"Paddle should be compiled with CUDA for ParallelGraph Execution.");
#endif
} else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
bool has_drop_last_read_op = details::HasDropLastReadOp(*graph);
auto possible_inference_graphs =
details::TrySeparateToMultipleSingleDeviceGraphs(graph);
if (!possible_inference_graphs.empty()) {
VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase";
auto *pg_exe = new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
member_->places_, std::move(possible_inference_graphs));
if (!has_drop_last_read_op) {
VLOG(5) << "Enable partial feed support in inference phase";
pg_exe->EnablePartialFeedSupport();
}
final_graphs = pg_exe->Graphs();
member_->executor_.reset(pg_exe);
member_->inference_executor_ = pg_exe;
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
LOG_IF(WARNING, details::HasKeepLastReadOp(*graph))
<< "drop_last=False for DataLoader is not supported in training "
"network. It is automatically turned to drop_last=True.";
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
}
final_graphs.emplace_back(graph);
}
final_graphs.emplace_back(graph);
}
VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
......@@ -735,6 +796,8 @@ FetchResultType ParallelExecutor::Run(
platform::RecordBlock b(0);
ResetHasFeedGuard reset_has_feed_guard(member_);
ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors,
member_->HasGarbageCollectors());
......@@ -745,10 +808,31 @@ FetchResultType ParallelExecutor::Run(
void ParallelExecutor::FeedTensorsIntoLocalScopes(
const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());
if (!member_->AllowPartialFeed()) {
PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(),
platform::errors::Unimplemented(
"The feed data number %d does not match the device "
"number %d. If you are using DataLoader to feed "
"data, this may be because you set drop_last=False "
"in training network. Currently, drop_last=False for "
"DataLoader is not supported for training network. "
"Please set drop_last=True when defining DataLoader.",
tensors.size(), member_->local_scopes_.size()));
} else {
PADDLE_ENFORCE_GE(member_->local_scopes_.size(), tensors.size(),
platform::errors::InvalidArgument(
"The feed tensor number exceeds the device number"));
}
size_t feed_num = 0;
for (size_t i = 0; i < tensors.size(); ++i) {
auto &map = tensors[i];
if (map.empty()) {
continue;
}
member_->SetHasFeed(i);
++feed_num;
for (auto &pair : map) {
bool is_persistable = member_->IsPersistable(pair.first);
if (!is_persistable) {
......@@ -763,11 +847,28 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
trg->set_lod(pair.second.lod());
}
}
if (!member_->AllowPartialFeed()) {
PADDLE_ENFORCE_EQ(feed_num, member_->local_scopes_.size(),
platform::errors::Unimplemented(
"The feed data number %d does not match the device "
"number %d. If you are using DataLoader to feed "
"data, this may be because you set drop_last=False "
"in training network. Currently, drop_last=False for "
"DataLoader is not supported for training network. "
"Please set drop_last=True when defining DataLoader.",
feed_num, member_->local_scopes_.size()));
}
}
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors) {
size_t num_places = member_->places_.size();
bool allow_partial_feed = member_->AllowPartialFeed();
size_t persistable_feed_len = -1UL;
size_t non_persistable_feed_len = -1UL;
for (auto &pair : tensors) {
bool is_persistable = member_->IsPersistable(pair.first);
VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable")
......@@ -775,7 +876,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
<< ", place: " << pair.second.place();
auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
if (!is_persistable && num_places != lod_tensors.size()) {
if (!is_persistable && num_places != lod_tensors.size() &&
!allow_partial_feed) {
auto error_info = string::Sprintf(
"The number(%d) of samples[%s] of current batch is less than the "
"count(%d) of devices(%s), currently, it is not allowed. ",
......@@ -801,7 +903,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
framework::TensorCopy(pair.second, member_->places_.at(i), &tmp);
}
}
if (lod_tensors.size() != num_places) {
if (lod_tensors.size() != num_places && !allow_partial_feed) {
auto error_info = string::Sprintf(
"The number(%d) of samples[%s] of the current batch does not match "
"the count(%d) of devices(%s). Because that %s is a persistable "
......@@ -815,7 +917,31 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
}
for (size_t j = 0; j < num_places; ++j) {
if (allow_partial_feed) {
if (is_persistable) {
if (persistable_feed_len == -1UL) {
persistable_feed_len = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(
persistable_feed_len, lod_tensors.size(),
platform::errors::InvalidArgument(
"The feeded number of different persistable variables "
"should be the same"));
}
} else {
if (non_persistable_feed_len == -1UL) {
non_persistable_feed_len = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(
non_persistable_feed_len, lod_tensors.size(),
platform::errors::InvalidArgument(
"The feeded number of different non-persistable variables "
"should be the same"));
}
}
}
for (size_t j = 0; j < lod_tensors.size(); ++j) {
auto *feed_scope = is_persistable ? member_->local_scopes_[j]
: member_->local_exec_scopes_[j];
auto *feed_var = feed_scope->Var(pair.first);
......@@ -825,6 +951,22 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
t->set_lod(lod_tensors[j].lod());
}
}
if (allow_partial_feed && persistable_feed_len != -1UL &&
non_persistable_feed_len != -1UL) {
VLOG(10) << "Persistable len " << persistable_feed_len;
VLOG(10) << "Non persistable len " << non_persistable_feed_len;
PADDLE_ENFORCE_GE(persistable_feed_len, non_persistable_feed_len,
platform::errors::InvalidArgument(
"The feeded number of persistable variables should "
"not be less than non-persistable variables"));
}
if (non_persistable_feed_len != -1UL) {
for (size_t i = 0; i < non_persistable_feed_len; ++i) {
member_->SetHasFeed(i);
}
}
}
ParallelExecutor::~ParallelExecutor() {
......@@ -875,6 +1017,10 @@ bool ParallelExecutor::EnableParallelGraphExecution(
return enable_parallel_graph;
}
const ir::Graph &ParallelExecutor::Graph() const {
return member_->executor_->Graph();
}
} // namespace framework
} // namespace paddle
......@@ -882,3 +1028,4 @@ USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
USE_PASS(init_reader_device_count_pass);
......@@ -80,6 +80,8 @@ class ParallelExecutor {
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged = true);
const ir::Graph &Graph() const;
private:
// broadcast the parameters from the 0th device.
// trainer_id the trainer index in nccl distributed training.
......
......@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
~DecoratedReader();
const std::shared_ptr<ReaderBase>& UnderlyingReader() const {
return reader_;
}
protected:
void ShutdownImpl() override {
VLOG(1) << "ShutdownImpl";
......@@ -190,6 +194,8 @@ class ReaderHolder {
return reader_->NeedCheckFeed();
}
void Clear() { reader_.reset(); }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
......
......@@ -56,6 +56,7 @@ class CudnnRNNCache;
namespace reader {
class LoDTensorBlockingQueueHolder;
class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
} // namespace reader
} // namespace operators
......@@ -139,6 +140,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
operators::reader::LoDTensorBlockingQueueHolder,
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_NCCL)
ncclUniqueId, platform::Communicator, platform::NCCLCommunicator,
......
......@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
if (out->Get() != nullptr) {
auto* decorated_reader =
dynamic_cast<framework::DecoratedReader*>(out->Get().get());
PADDLE_ENFORCE_NOT_NULL(
decorated_reader,
platform::errors::NotFound("Not inited with DecoratedReader"));
if (decorated_reader->UnderlyingReader() == underlying_reader.Get()) {
return;
}
}
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "AUTO") {
......@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
VLOG(10) << "Create new double buffer reader on " << place;
out->Reset(framework::MakeDecoratedReader<BufferedReader>(underlying_reader,
place, 2));
}
......
......@@ -38,8 +38,21 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name);
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
std::shared_ptr<LoDTensorBlockingQueue> queue;
std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> ordered_queue;
int dev_idx = -1;
if (queue_holder_var->IsType<LoDTensorBlockingQueueHolder>()) {
queue = queue_holder_var->Get<LoDTensorBlockingQueueHolder>().GetQueue();
} else if (queue_holder_var
->IsType<OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
auto* queue_holder =
queue_holder_var
->GetMutable<OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
dev_idx = Attr<int>("device_index");
ordered_queue = queue_holder->GetQueue();
ordered_queue->SetDeviceCount(Attr<int>("device_count"));
queue = ordered_queue->GetQueue(dev_idx);
}
/* Converting shape_concat and ranks into the DDim of each data.
shape_concat and ranks are the shapes and shape ranks of each data. E.g.
......@@ -71,8 +84,12 @@ class CreatePyReaderOp : public framework::OperatorBase {
for (size_t i = 0; i < need_check_feed_int.size(); ++i) {
need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i]));
}
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue(), dims,
var_types, need_check_feed));
auto py_reader =
std::make_shared<PyReader>(queue, dims, var_types, need_check_feed);
if (ordered_queue) {
ordered_queue->SetResetMethod(dev_idx, [out] { out->Clear(); });
}
out->Reset(py_reader);
}
};
......@@ -82,6 +99,13 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
AddInput("blocking_queue",
"Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("device_index", "The device index this reader offers data")
.SetDefault(0);
AddAttr<int>("device_count",
"The total device number this reader offers data")
.SetDefault(1);
AddComment(R"DOC(
Create PyReader to support LoDTensor data feeding in Python side.
)DOC");
......
......@@ -27,16 +27,13 @@ namespace paddle {
namespace operators {
namespace reader {
class LoDTensorBlockingQueueHolder;
class LoDTensorBlockingQueue {
friend class LoDTensorBlockingQueueHolder;
private:
public:
explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
: queue_(capacity, speed_test_mode) {}
public:
~LoDTensorBlockingQueue() { VLOG(10) << "Destruct LoDTensorBlockingQueue"; }
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return queue_.Send(lod_tensor_vec);
}
......@@ -67,10 +64,140 @@ class LoDTensorBlockingQueue {
inline void Kill() { queue_.Kill(); }
inline bool WaitForInited(size_t) { return true; }
private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_;
};
class OrderedMultiDeviceLoDTensorBlockingQueue {
public:
OrderedMultiDeviceLoDTensorBlockingQueue(size_t capacity,
bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode) {}
~OrderedMultiDeviceLoDTensorBlockingQueue() {
VLOG(10) << "Destruct OrderedMultiDeviceLoDTensorBlockingQueue";
}
bool WaitForInited(size_t milliseconds) {
std::unique_lock<std::mutex> lock(init_mutex_);
return cv_.wait_for(lock, std::chrono::milliseconds(milliseconds),
[this] { return !queues_.empty(); });
}
void SetDeviceCount(size_t dev_cnt) {
{
std::lock_guard<std::mutex> lock(init_mutex_);
PADDLE_ENFORCE_GE(dev_cnt, 1,
platform::errors::InvalidArgument(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"));
if (!queues_.empty()) {
PADDLE_ENFORCE_EQ(queues_.size(), dev_cnt,
platform::errors::InvalidArgument(
"queues should be only inited once"));
return;
}
VLOG(1) << "Init queue with size " << dev_cnt;
queues_.resize(dev_cnt);
for (auto& item : queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
}
cv_.notify_all();
}
const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue(size_t idx) const {
EnforceIsInited();
PADDLE_ENFORCE_LT(
idx, queues_.size(),
platform::errors::OutOfRange("The queue index is out of range"));
return queues_[idx];
}
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return CurQueue()->Push(lod_tensor_vec);
}
inline size_t Size() const {
size_t size = 0;
for (auto& item : queues_) {
size += item->Size();
}
return size;
}
inline void Close() {
for (auto& item : queues_) {
item->Close();
}
}
inline void Kill() {
for (auto& item : queues_) {
item->Kill();
}
}
inline void Reset() {
{
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
for (auto& method : reset_methods_) {
if (method) method();
}
}
auto dev_cnt = queues_.size();
for (auto& item : queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
data_index_ = 0;
}
inline void SetResetMethod(size_t idx,
const std::function<void()>& reset_method) {
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
EnforceIsInited();
if (reset_methods_.size() <= idx) {
reset_methods_.resize(idx + 1);
}
reset_methods_[idx] = reset_method;
}
inline size_t Cap() const { return capacity_; }
private:
const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() {
return queues_[(data_index_++) % queues_.size()];
}
private:
void EnforceIsInited() const {
PADDLE_ENFORCE_EQ(queues_.empty(), false,
platform::errors::NotFound("queue has not been inited"));
}
private:
std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
mutable uint64_t data_index_{0};
size_t dev_cnt_{0};
const size_t capacity_;
const bool speed_test_mode_;
bool is_closed_{false};
std::vector<std::function<void()>> reset_methods_;
mutable std::mutex reset_mutex_;
mutable std::mutex init_mutex_;
mutable std::condition_variable cv_;
};
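The round-robin dispatch done by CurQueue() above can be illustrated with a toy sketch, where plain vectors stand in for the per-device sub-queues: the k-th pushed batch goes to sub-queue k % dev_cnt, so each device receives a deterministic, ordered slice of the data stream.

// Toy sketch of the round-robin dispatch; vectors are stand-ins for the
// real per-device LoDTensorBlockingQueue instances.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t dev_cnt = 3;
  uint64_t data_index = 0;
  std::vector<std::vector<int>> queues(dev_cnt);  // per-device sub-queues
  for (int batch = 0; batch < 8; ++batch) {
    queues[(data_index++) % dev_cnt].push_back(batch);
  }
  for (size_t i = 0; i < dev_cnt; ++i) {
    std::cout << "device " << i << ":";
    for (int b : queues[i]) std::cout << " " << b;
    std::cout << "\n";  // device 0: 0 3 6 | device 1: 1 4 7 | device 2: 2 5
  }
  return 0;
}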
class LoDTensorBlockingQueueHolder {
public:
void InitOnce(size_t capacity, bool speed_test_mode = false) {
......@@ -88,6 +215,26 @@ class LoDTensorBlockingQueueHolder {
std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
class OrderedMultiDeviceLoDTensorBlockingQueueHolder {
public:
void InitOnce(size_t capacity, bool speed_test_mode = false) {
PADDLE_ENFORCE_EQ(queue_, nullptr,
platform::errors::AlreadyExists(
"OrderedMultiDeviceLoDTensorBlockingQueueHolder::"
"InitOnce() can only be called once"));
queue_.reset(new OrderedMultiDeviceLoDTensorBlockingQueue(capacity,
speed_test_mode));
}
inline const std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue>&
GetQueue() const {
return queue_;
}
private:
std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> queue_;
};
} // namespace reader
} // namespace operators
} // namespace paddle
......@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" and it is set by ParallelExecutor instance, not users.")
.SetDefault(true);
AddAttr<bool>("infer_out", "").SetDefault(true);
AddAttr<bool>("drop_last",
"Whether to drop last batches whose number is less than CPU "
"cores/GPU cards number")
.SetDefault(true);
AddComment(R"DOC(
Read Operator
......
......@@ -51,7 +51,6 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
......@@ -94,9 +93,6 @@ limitations under the License. */
#include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_NGRAPH
DECLARE_bool(use_ngraph);
......@@ -997,35 +993,6 @@ All parameter, weight, gradient are variables in Paddle.
BindReader(&m);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
using LoDTensorBlockingQueueHolder =
::paddle::operators::reader::LoDTensorBlockingQueueHolder;
py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
pybind11::gil_scoped_release release;
return self.Push(lod_tensor_vec);
})
.def("size", &LoDTensorBlockingQueue::Size)
.def("capacity", &LoDTensorBlockingQueue::Cap)
.def("close", &LoDTensorBlockingQueue::Close)
.def("kill", &LoDTensorBlockingQueue::Kill)
.def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
[](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
VLOG(1) << "init_lod_tensor_blocking_queue";
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
py::return_value_policy::copy);
py::class_<Scope>(m, "_Scope", R"DOC(
Scope is an association of a name to Variable. All variables belong to Scope.
......
......@@ -20,47 +20,141 @@
#include <utility>
#include <vector>
#include "Python.h"
#include "boost/optional.hpp"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
#include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
namespace paddle {
namespace pybind {
namespace py = pybind11;
namespace reader = operators::reader;
// Check whether the tensor shape matches the VarDesc shape.
// Return the mismatched shape if one exists.
static boost::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
const framework::LoDTensor &tensor, const framework::VarDesc &var_desc,
size_t num_places) {
auto tensor_shape = tensor.dims();
auto desc_shape = var_desc.GetShape();
int64_t rank = tensor_shape.size();
if (UNLIKELY(rank == 0)) {
if (desc_shape.size() != 0) { // Tensor rank = 0 but desc does not match
return framework::vectorize<int64_t>(tensor_shape);
} else {
return boost::none;
}
}
PADDLE_ENFORCE_GE(tensor_shape[0], 0,
platform::errors::InvalidArgument(
"Tensor shape at dim 0 must not be less than 0"));
if (!tensor.lod().empty()) {
tensor_shape[0] = -1; // unknown shape
} else {
int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
tensor_shape[0] = split_size;
if (desc_shape[0] >= 0) { // need check dim 0
if (tensor_shape[0] != desc_shape[0]) {
return framework::vectorize<int64_t>(tensor_shape);
}
if (remainder > 0) {
tensor_shape[0] = remainder;
return framework::vectorize<int64_t>(tensor_shape);
}
}
}
for (int64_t idx = 1; idx < rank; ++idx) {
PADDLE_ENFORCE_GE(
tensor_shape[idx], 0,
platform::errors::InvalidArgument(
"Tensor shape at dim %d must not be less than 0", idx));
if (desc_shape[idx] >= 0 && tensor_shape[idx] != desc_shape[idx]) {
return framework::vectorize<int64_t>(tensor_shape);
}
}
return boost::none;
}
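A worked example of the dim-0 handling above, as a small standalone arithmetic sketch under the assumption of a batch of 10 samples split over 4 places: split_size becomes 3 and the trailing remainder becomes 1, and both values are checked against the VarDesc's dim 0 when it is non-negative.

// Hypothetical worked example of the per-place dim-0 split arithmetic.
#include <cstdint>
#include <iostream>

int main() {
  int64_t dim0 = 10, num_places = 4;
  int64_t split_size = (dim0 + num_places - 1) / num_places;      // 3
  int64_t remainder = (split_size == 0 ? 0 : dim0 % split_size);  // 1
  std::cout << split_size << " " << remainder << "\n";            // prints "3 1"
  return 0;
}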
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
return queue;
}
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
&queue,
size_t idx) {
return queue->GetQueue(idx);
}
template <typename QueueType>
class MultiDeviceFeedReader {
public:
using ResultDictList =
std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
using ResultList = std::vector<std::vector<framework::LoDTensor>>;
static constexpr bool kKeepOrder =
std::is_same<QueueType,
reader::OrderedMultiDeviceLoDTensorBlockingQueue>::value;
MultiDeviceFeedReader(
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
const std::shared_ptr<QueueType> &queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer)
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last)
: queue_(queue),
names_(names),
pool_(new ::ThreadPool(dst_places.size())) {
pool_(new ::ThreadPool(dst_places.size())),
drop_last_(drop_last) {
std::vector<framework::DDim> dims;
for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape));
}
std::shared_ptr<framework::ReaderBase> reader(
new operators::reader::PyReader(queue, dims, dtypes, need_check_feed));
auto first_reader = std::make_shared<reader::PyReader>(
GetQueue(queue, 0), dims, dtypes, need_check_feed);
auto create_or_get_reader = [&](size_t idx) {
if (idx == 0 ||
std::is_same<QueueType, reader::LoDTensorBlockingQueue>::value) {
return first_reader;
} else {
return std::make_shared<reader::PyReader>(GetQueue(queue, idx), dims,
dtypes, need_check_feed);
}
};
readers_.reserve(dst_places.size());
for (auto &p : dst_places) {
for (size_t i = 0; i < dst_places.size(); ++i) {
auto &p = dst_places[i];
auto *holder = new framework::ReaderHolder();
auto reader = create_or_get_reader(i);
if (use_double_buffer) {
VLOG(10) << "Creating " << i << "-th BufferedReader";
holder->Reset(
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
reader, p, 2));
......@@ -80,12 +174,22 @@ class MultiDeviceFeedReader {
ReadAsync();
}
bool DropLast() const { return drop_last_; }
ResultDictList ReadNext() {
CheckNextStatus();
ResultDictList result(ret_.size());
ResultDictList result;
result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
if (ret_[i].empty()) {
if (!kKeepOrder) result.emplace_back();
continue;
}
result.emplace_back();
auto &ret = result.back();
for (size_t j = 0; j < names_.size(); ++j) {
result[i].emplace(names_[j], std::move(ret_[i][j]));
ret.emplace(names_[j], std::move(ret_[i][j]));
}
}
ReadAsync();
......@@ -97,6 +201,7 @@ class MultiDeviceFeedReader {
ResultList result;
result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
if (kKeepOrder && ret_[i].empty()) continue;
result.emplace_back(std::move(ret_[i]));
}
ReadAsync();
......@@ -122,24 +227,29 @@ class MultiDeviceFeedReader {
};
Status WaitFutures(std::exception_ptr *excep) {
bool is_success = true;
*excep = nullptr;
size_t success_num = 0;
for (size_t i = 0; i < futures_.size(); ++i) {
auto each_status = futures_[i].get();
if (UNLIKELY(each_status != Status::kSuccess)) {
is_success = false;
if (UNLIKELY(each_status == Status::kException)) {
PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
*excep = exceptions_[i];
exceptions_[i] = nullptr;
}
} else {
++success_num;
}
}
if (UNLIKELY(*excep)) {
return Status::kException;
}
if (drop_last_) {
return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
} else {
return is_success ? Status::kSuccess : Status::kEOF;
return success_num > 0 ? Status::kSuccess : Status::kEOF;
}
}
......@@ -183,7 +293,7 @@ class MultiDeviceFeedReader {
PADDLE_ENFORCE_EQ(status, Status::kSuccess);
}
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
std::shared_ptr<QueueType> queue_;
std::vector<std::string> names_;
std::unique_ptr<::ThreadPool> pool_;
......@@ -193,24 +303,21 @@ class MultiDeviceFeedReader {
std::vector<std::exception_ptr> exceptions_;
std::vector<std::vector<framework::LoDTensor>> ret_;
bool drop_last_;
};
void BindReader(py::module *module) {
template <typename QueueType>
void BindMultiDeviceReader(py::module *module, const char *reader_name) {
auto &m = *module;
namespace reader = ::paddle::operators::reader;
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
.def("read_next", &MultiDeviceFeedReader::ReadNext,
using ReaderType = MultiDeviceFeedReader<QueueType>;
py::class_<ReaderType>(m, reader_name, "")
.def("read_next", &ReaderType::ReadNext,
py::call_guard<py::gil_scoped_release>())
.def("read_next_list", &MultiDeviceFeedReader::ReadNextList,
.def("read_next_list", &ReaderType::ReadNextList,
py::call_guard<py::gil_scoped_release>())
.def("read_next_var_list",
[](MultiDeviceFeedReader &self) {
[](ReaderType &self) {
auto result_list = self.ReadNextList();
auto &tensor_list = result_list[0];
std::vector<std::shared_ptr<imperative::VarBase>> var_list;
......@@ -234,23 +341,116 @@ void BindReader(py::module *module) {
return var_list;
},
py::call_guard<py::gil_scoped_release>())
.def("reset", &MultiDeviceFeedReader::Reset,
.def("reset", &ReaderType::Reset,
py::call_guard<py::gil_scoped_release>());
}
void BindReader(py::module *module) {
auto &m = *module;
m.def("diff_tensor_shape", [](const framework::LoDTensor &tensor,
const framework::VarDesc &var_desc,
size_t num_places) -> py::object {
auto diff = DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
if (diff) {
return py::cast(std::move(diff.get()));
} else {
return py::cast(nullptr);
}
});
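// diff_tensor_shape returns None to Python when the fed tensor is compatible
// with the variable's compiled shape (taking num_places into account);
// otherwise it returns the incompatible shape so the Python side can raise a
// readable error (see check_feed_shape_type below).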
m.def("init_lod_tensor_blocking_queue",
[](framework::Variable &var, size_t capacity,
bool is_ordered) -> py::object {
VLOG(1) << "init_lod_tensor_blocking_queue";
if (is_ordered) {
auto *holder = var.GetMutable<
reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return py::cast(holder->GetQueue());
} else {
auto *holder =
var.GetMutable<reader::LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return py::cast(holder->GetQueue());
}
},
py::return_value_policy::copy);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<reader::LoDTensorBlockingQueue,
std::shared_ptr<reader::LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](reader::LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
return self.Push(lod_tensor_vec);
},
py::call_guard<py::gil_scoped_release>())
.def("size", &reader::LoDTensorBlockingQueue::Size)
.def("capacity", &reader::LoDTensorBlockingQueue::Cap)
.def("close", &reader::LoDTensorBlockingQueue::Close)
.def("kill", &reader::LoDTensorBlockingQueue::Kill)
.def("wait_for_inited", &reader::LoDTensorBlockingQueue::WaitForInited,
py::call_guard<py::gil_scoped_release>());
py::class_<reader::OrderedMultiDeviceLoDTensorBlockingQueue,
std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>>(
m, "OrderedMultiDeviceLoDTensorBlockingQueue", "")
.def("push",
[](reader::OrderedMultiDeviceLoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
return self.Push(lod_tensor_vec);
},
py::call_guard<py::gil_scoped_release>())
.def("size", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Size)
.def("capacity", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Cap)
.def("close", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Close)
.def("kill", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Kill)
.def("wait_for_inited",
&reader::OrderedMultiDeviceLoDTensorBlockingQueue::WaitForInited,
py::call_guard<py::gil_scoped_release>())
.def("reset", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Reset);
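// Expose both reader flavours to Python: MultiDeviceFeedReader works on the
// plain queue, while OrderedMultiDeviceFeedReader works on the keep-order
// queue created with is_ordered=True.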
BindMultiDeviceReader<reader::LoDTensorBlockingQueue>(
module, "MultiDeviceFeedReader");
BindMultiDeviceReader<reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
module, "OrderedMultiDeviceFeedReader");
m.def("create_py_reader",
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue>
&queue,
[](const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer) {
return new MultiDeviceFeedReader(queue, names, shapes, dtypes,
need_check_feed, dst_places,
use_double_buffer);
bool use_double_buffer, bool drop_last) {
return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer, drop_last);
},
py::return_value_policy::take_ownership);
m.def(
"create_py_reader",
[](const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
&queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last) {
queue->SetDeviceCount(dst_places.size());
return new MultiDeviceFeedReader<
reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer, drop_last);
},
py::return_value_policy::take_ownership);
}
} // namespace pybind
......
......@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1):
the feed value
"""
if var.desc.need_check_feed():
feed_shape = feed.shape()
if six.PY2:
feed_shape[0] = long(feed_shape[0] /
num_places) if len(feed.lod()) == 0 else -1
else:
feed_shape[0] = int(feed_shape[0] /
num_places) if len(feed.lod()) == 0 else -1
if not dimension_is_compatible_with(feed_shape, var.shape):
diff_shape = core.diff_tensor_shape(feed, var.desc, num_places)
if diff_shape is not None:
raise ValueError(
'The fed Variable %r should have dimensions = %d, shape = '
'%r, but received fed shape %r on each device' %
(var.name, len(var.shape), var.shape, feed_shape))
(var.name, len(var.shape), var.shape, diff_shape))
if not dtype_is_compatible_with(feed._dtype(), var.dtype):
var_dtype_format = convert_dtype(var.dtype) if isinstance(
var.dtype, core.VarDesc.VarType) else var.dtype
......@@ -646,11 +640,6 @@ class Executor(object):
exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(program._places):
raise ValueError(
"Feed a list of tensor, the list should be the same size as places"
)
res = list()
for i, each in enumerate(feed):
if not isinstance(each, dict):
......
......@@ -429,7 +429,7 @@ def _py_reader(capacity,
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, False)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name)
......
......@@ -48,6 +48,17 @@ __all__ = ['PyReader', 'DataLoader']
data_loader_unique_name_generator = UniqueNameGenerator()
KEEP_DATA_LOADER_ORDER = True
def keep_data_loader_order(*args):
global KEEP_DATA_LOADER_ORDER
if len(args) == 0:
return KEEP_DATA_LOADER_ORDER
else:
assert len(args) == 1 and isinstance(args[0], bool)
KEEP_DATA_LOADER_ORDER = args[0]
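# Called with no argument, keep_data_loader_order() returns the current flag;
# keep_data_loader_order(False) switches loaders built afterwards to the
# unordered multi-device queue (see test_dataloader_unkeep_order below).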
def _convert_places(places):
if not isinstance(places, (list, tuple)):
......@@ -172,8 +183,12 @@ class DataLoader(object):
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False):
use_multiprocess=False,
drop_last=True):
"""
.. note::
**The framework ensures that the data loading order of DataLoader is exactly the same as the order of the user-defined data source.**
Create a DataLoader object for loading data from Python generator.
Data would be prefetched using Python thread and be pushed
into a queue asynchronously.
......@@ -182,7 +197,7 @@ class DataLoader(object):
:code:`set_sample_generator` , :code:`set_sample_list_generator` and
:code:`set_batch_generator` . Please see the following example codes
to know their usages.
If iterable = True, the created DataLoader object is a Python generator
object, which is iterable using for-range loop.
......@@ -218,11 +233,18 @@ class DataLoader(object):
can be used in the dygraph mode. In the static graph mode,
whether this parameter is set or not has no effect.
The Default value is False.
drop_last (bool): whether to drop the trailing batches whose
count is less than the CPU core/GPU card number. The default
value is True. During training, users should not set
drop_last=False, because every CPU core/GPU card must read
data from the DataLoader. During inference, drop_last=False
can be set so that the trailing batches, even if fewer than
the CPU core/GPU card number, can still be tested.
Returns:
loader (DataLoader): the created DataLoader object.
Examples:
Examples 1:
.. code-block:: python
......@@ -354,6 +376,49 @@ class DataLoader(object):
assert image.shape == [BATCH_SIZE, 784]
assert label.shape == [BATCH_SIZE, 1]
assert relu.shape == [BATCH_SIZE, 784]
Examples 2:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import os
# We use 2 CPU cores to run the inference network
os.environ['CPU_NUM'] = '2'
# The data source has only 3 batches, which cannot be
# evenly divided among the CPU cores
def batch_generator():
for i in range(3):
yield np.array([i+1]).astype('float32'),
x = fluid.data(name='x', shape=[None], dtype='float32')
y = x * x
def run_inference(drop_last):
loader = fluid.io.DataLoader.from_generator(feed_list=[x],
capacity=8, drop_last=drop_last)
loader.set_batch_generator(batch_generator, fluid.cpu_places())
exe = fluid.Executor(fluid.CPUPlace())
prog = fluid.CompiledProgram(fluid.default_main_program())
prog = prog.with_data_parallel()
result = []
for data in loader():
each_ret, = exe.run(prog, feed=data, fetch_list=[y])
result.extend(each_ret)
return result
# Set drop_last to True, so that the trailing batches whose
# count is less than the CPU core number are discarded.
print(run_inference(drop_last=True)) # [1.0, 4.0]
# Set drop_last to False, so that the trailing batches whose
# count is less than the CPU core number can still be tested.
print(run_inference(drop_last=False)) # [1.0, 4.0, 9.0]
"""
if in_dygraph_mode():
return DygraphGeneratorLoader(feed_list, capacity,
......@@ -361,7 +426,7 @@ class DataLoader(object):
return_list, use_multiprocess)
else:
return GeneratorLoader(feed_list, capacity, use_double_buffer,
iterable, return_list)
iterable, return_list, drop_last)
@staticmethod
def from_dataset(dataset, places, drop_last=True):
......@@ -514,10 +579,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._dtypes = []
self._need_check_feed = []
self._blocking_queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._capacity)
core.Variable(), self._capacity, False)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
self._need_check_feed, self._places, self._use_double_buffer, True)
def _start(self):
if self._use_multiprocess:
......@@ -728,12 +793,16 @@ class GeneratorLoader(DataLoaderBase):
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False):
return_list=False,
drop_last=True):
self._tensor_reader = None
self._places = None
self._thread = None
self._queue = None
self._feed_list = feed_list
self._exited = False
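# drop_last and the keep-order flag are fixed when the loader is built; the
# latter is taken from the module-level keep_data_loader_order() switch.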
self._drop_last = drop_last
self._keep_order = keep_data_loader_order()
if not capacity:
raise ValueError("Please give value to capacity.")
self._iterable = iterable
......@@ -761,11 +830,12 @@ class GeneratorLoader(DataLoaderBase):
self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list
]
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(),
self._capacity)
self._queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._capacity, self._keep_order)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
self._need_check_feed, self._places, self._use_double_buffer,
self._drop_last)
def _init_non_iterable(self):
lod_levels = []
......@@ -789,16 +859,21 @@ class GeneratorLoader(DataLoaderBase):
double_buffer_name = data_loader_unique_name_generator('double_buffer')
var = global_scope().var(queue_name)
self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity)
self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity,
self._keep_order)
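# With keep_order enabled, the create_py_reader op is added directly to the
# main program; otherwise it is created in the startup program and copied
# into the main program further below.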
if self._keep_order:
block = default_main_program().current_block()
else:
block = default_startup_program().current_block()
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name)
reader_var = block.create_var(name=reader_name)
dtype_int = [int(t) for t in dtypes]
startup_blk.append_op(
block.append_op(
type='create_py_reader',
inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]},
outputs={'Out': [reader_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
......@@ -807,16 +882,23 @@ class GeneratorLoader(DataLoaderBase):
'ranks': ranks
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
reader_var.desc.set_dtypes(dtypes)
reader_var.persistable = True
reader_var.stop_gradient = True
if self._keep_order:
main_prog_var = reader_var
reader = main_prog_var
reader.reset = self._queue.reset
else:
main_prog_var = _copy_reader_var_(
default_main_program().current_block(), reader_var)
main_prog_var = _copy_reader_var_(
default_main_program().current_block(), startup_var)
main_prog_var.stop_gradient = True
main_prog_var.persistable = True
main_prog_var.stop_gradient = True
main_prog_var.persistable = True
reader = monkey_patch_reader_methods(main_prog_var)
reader = monkey_patch_reader_methods(main_prog_var)
if self._use_double_buffer:
double_buffer_reader = double_buffer(
reader, name=double_buffer_name)
......@@ -830,7 +912,8 @@ class GeneratorLoader(DataLoaderBase):
default_main_program().current_block().append_op(
type='read',
inputs={'Reader': [self._reader]},
outputs={'Out': self._feed_list})
outputs={'Out': self._feed_list},
attrs={'drop_last': self._drop_last})
@property
def queue(self):
......@@ -879,14 +962,20 @@ class GeneratorLoader(DataLoaderBase):
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
return arr
def _start(self):
def __thread_main__():
try:
while not self._queue.wait_for_inited(1):
if self._exited:
return
for tensors in self._tensor_reader():
array = core.LoDTensorArray()
for item in tensors:
if not isinstance(item, core.LoDTensor):
self._check_input_array(item)
item = self._check_input_array(item)
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
......@@ -910,10 +999,12 @@ class GeneratorLoader(DataLoaderBase):
def _reset(self):
self._queue.close()
self._exited = True
thread = self._thread
if thread is not None:
thread.join()
self._exited = False
self._reader.reset()
def set_sample_generator(self,
......
......@@ -359,7 +359,10 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_parallel_executor_feed_persistable_var
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_fetch_unmerged
test_optimizer_in_control_flow test_dataloader_keep_order
test_dataloader_unkeep_order
test_parallel_executor_inference_feed_partial_data
test_fetch_unmerged
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
if(NOT WIN32 AND NOT APPLE)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import os
import six
def create_reader(shape, batch_number):
def __impl__():
idx = 0
for _ in six.moves.range(batch_number):
yield np.ones(shape).astype('float32') * idx,
idx += 1
return __impl__
class DataLoaderKeepOrderTestBase(unittest.TestCase):
def initParameters(self):
self.iterable = False
self.break_num = 10000
def setUp(self):
self.epoch_num = 3
self.batch_num = 40
self.shape = [3, 4, 5]
self.initParameters()
def build_network(self, places):
input_data = fluid.data(shape=self.shape, dtype='float32', name="input")
loader = fluid.io.DataLoader.from_generator(
capacity=16, feed_list=[input_data], iterable=self.iterable)
fc = fluid.layers.fc(input_data, size=10)
loss = fluid.layers.reduce_mean(fc)
loader.set_batch_generator(
create_reader(self.shape, self.batch_num),
places=places if loader.iterable else None)
return input_data, loss, loader
def assertInputData(self, batch_id, input_data, dev_cnt):
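# The generator yields constant tensors 0, 1, 2, ...; since the loader keeps
# the data source order, device i at step batch_id must observe the value
# batch_id * dev_cnt + i.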
if isinstance(input_data, list):
self.assertEqual(len(input_data), dev_cnt)
start_val = dev_cnt * batch_id
for each_input_dict in input_data:
input_tensor = np.array(each_input_dict["input"])
self.assertEqual(self.shape, list(input_tensor.shape))
self.assertTrue((input_tensor == start_val).all())
start_val += 1
else:
self.assertEqual(
list(input_data.shape),
[self.shape[0] * dev_cnt] + self.shape[1:])
start_val = dev_cnt * batch_id
for idx in six.moves.range(dev_cnt):
data_part = input_data[idx * self.shape[0]:(idx + 1) *
self.shape[0], :]
self.assertTrue((data_part == start_val).all())
start_val += 1
def get_places(self):
place_list = [fluid.cpu_places(1), fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
place_list.extend([fluid.cuda_places(0), fluid.cuda_places([0, 1])])
return place_list
def test_main(self):
for p in self.get_places():
use_compiled_program_list = [True] if len(p) > 1 else [False, True]
for use_compiled_program in use_compiled_program_list:
self.run_main_with_place(p, use_compiled_program)
def run_main_with_place(self, places, use_compiled_program=True):
with fluid.scope_guard(fluid.Scope()):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data, loss, loader = self.build_network(places)
fetch_list = [input_data]
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
dev_cnt = len(places)
if dev_cnt > 1:
self.assertTrue(use_compiled_program)
main_program = fluid.default_main_program()
if use_compiled_program:
main_program = fluid.CompiledProgram(
main_program).with_data_parallel(
loss_name=loss.name, places=places)
max_batch_num = min(self.break_num,
int(self.batch_num / dev_cnt))
if loader.iterable:
early_break = False
for epoch_id in six.moves.range(self.epoch_num):
early_break = False
batch_id = 0
for data in loader():
if batch_id >= self.break_num:
early_break = True
break
self.assertInputData(batch_id, data, dev_cnt)
fetch_val, = exe.run(program=main_program,
feed=data,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val, dev_cnt)
batch_id += 1
self.assertEqual(batch_id, max_batch_num)
if early_break:
loader._reset()
else:
for epoch_id in six.moves.range(self.epoch_num):
batch_id = 0
loader.start()
try:
while True:
if batch_id >= self.break_num:
loader.reset()
break
fetch_val, = exe.run(program=main_program,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val,
dev_cnt)
batch_id += 1
except fluid.core.EOFException:
loader.reset()
self.assertEqual(batch_id, max_batch_num)
class IterableDataLoaderKeepOrderTest2(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 10000
class IterableDataLoaderKeepOrderTest3(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 2
class IterableDataLoaderKeepOrderTest4(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 2
class IterableDataLoaderKeepOrderTest5(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 0
class IterableDataLoaderKeepOrderTest6(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 0
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import os
import six
from paddle.fluid.reader import keep_data_loader_order
keep_data_loader_order(False)
def create_reader(shape, batch_number):
def __impl__():
idx = 0
for _ in six.moves.range(batch_number):
yield np.ones(shape).astype('float32') * idx,
idx += 1
return __impl__
class DataLoaderKeepOrderTestBase(unittest.TestCase):
def initParameters(self):
self.iterable = False
self.break_num = 10000
def setUp(self):
self.epoch_num = 3
self.batch_num = 40
self.shape = [3, 4, 5]
self.initParameters()
def clear_visited(self):
self.visited = set()
def build_network(self, places):
input_data = fluid.data(shape=self.shape, dtype='float32', name="input")
loader = fluid.io.DataLoader.from_generator(
capacity=16, feed_list=[input_data], iterable=self.iterable)
fc = fluid.layers.fc(input_data, size=10)
loss = fluid.layers.reduce_mean(fc)
loader.set_batch_generator(
create_reader(self.shape, self.batch_num),
places=places if loader.iterable else None)
return input_data, loss, loader
def assertInputData(self, batch_id, input_data, dev_cnt,
check_visited=True):
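# With keep_data_loader_order(False), batches may be dispatched to devices in
# any order, so the exact value is not checked; instead every constant batch
# value must be visited at most once per epoch.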
if isinstance(input_data, list):
self.assertEqual(len(input_data), dev_cnt)
start_val = dev_cnt * batch_id
for each_input_dict in input_data:
input_tensor = np.array(each_input_dict["input"])
self.assertEqual(self.shape, list(input_tensor.shape))
num = input_tensor.flatten()[0]
equal = (input_tensor == num).all()
self.assertTrue(equal)
if check_visited:
self.assertTrue(num not in self.visited)
self.visited.add(num)
start_val += 1
else:
self.assertEqual(
list(input_data.shape),
[self.shape[0] * dev_cnt] + self.shape[1:])
start_val = dev_cnt * batch_id
for idx in six.moves.range(dev_cnt):
data_part = input_data[idx * self.shape[0]:(idx + 1) *
self.shape[0], :]
num = data_part.flatten()[0]
self.assertTrue((data_part == num).all())
if check_visited:
self.assertTrue(num not in self.visited)
self.visited.add(num)
start_val += 1
def get_places(self):
place_list = [fluid.cpu_places(1), fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
place_list.extend([fluid.cuda_places(0), fluid.cuda_places([0, 1])])
return place_list
def test_main(self):
for p in self.get_places():
use_compiled_program_list = [True] if len(p) > 1 else [False, True]
for use_compiled_program in use_compiled_program_list:
self.run_main_with_place(p, use_compiled_program)
def run_main_with_place(self, places, use_compiled_program=True):
with fluid.scope_guard(fluid.Scope()):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data, loss, loader = self.build_network(places)
fetch_list = [input_data]
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
dev_cnt = len(places)
if dev_cnt > 1:
self.assertTrue(use_compiled_program)
main_program = fluid.default_main_program()
if use_compiled_program:
main_program = fluid.CompiledProgram(
main_program).with_data_parallel(
loss_name=loss.name, places=places)
max_batch_num = min(self.break_num,
int(self.batch_num / dev_cnt))
if loader.iterable:
early_break = False
for epoch_id in six.moves.range(self.epoch_num):
early_break = False
self.clear_visited()
batch_id = 0
for data in loader():
if batch_id >= self.break_num:
early_break = True
break
self.assertInputData(
batch_id, data, dev_cnt, check_visited=False)
fetch_val, = exe.run(program=main_program,
feed=data,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val, dev_cnt)
batch_id += 1
if dev_cnt == 1:
self.assertEqual(batch_id, max_batch_num)
else:
self.assertLessEqual(batch_id, max_batch_num)
if early_break:
loader._reset()
else:
for epoch_id in six.moves.range(self.epoch_num):
batch_id = 0
self.clear_visited()
loader.start()
try:
while True:
if batch_id >= self.break_num:
loader.reset()
break
fetch_val, = exe.run(program=main_program,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val,
dev_cnt)
batch_id += 1
except fluid.core.EOFException:
loader.reset()
if dev_cnt == 1:
self.assertEqual(batch_id, max_batch_num)
else:
self.assertLessEqual(batch_id, max_batch_num)
class IterableDataLoaderKeepOrderTest2(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 10000
class IterableDataLoaderKeepOrderTest3(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 2
class IterableDataLoaderKeepOrderTest4(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 2
class IterableDataLoaderKeepOrderTest5(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 0
class IterableDataLoaderKeepOrderTest6(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 0
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
import unittest
import six
class TestInferencePartialFeed(unittest.TestCase):
def setUp(self):
self.iterations = 10
self.size = 10
def run_network(self, places, use_split, has_persistable):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
if has_persistable:
lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
else:
lr = fluid.data(name='lr', shape=[None], dtype='float32')
relu_x = fluid.layers.relu(x)
relu_y = fluid.layers.relu(y)
relu_lr = fluid.layers.relu(lr)
exe = fluid.Executor(places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog).with_data_parallel(
places=places)
gen_random = lambda shape: np.random.uniform(
    low=-1.0, high=1.0, size=shape).astype('float32')
assert_result = lambda feed, result: self.assertTrue(
    np.array_equal(np.maximum(0, feed), result))
def assert_merged_unmerged(merged, unmerged):
unmerged = np.concatenate(unmerged, axis=0)
self.assertTrue(np.array_equal(merged, unmerged))
def feed_split_test():
for place_num in six.moves.range(1, len(places) * 3):
x_np = gen_random([place_num, self.size])
y_np = gen_random([place_num, self.size])
if not lr.persistable or place_num <= len(places):
lr_np = gen_random([place_num])
else:
lr_np = gen_random([1])
feed = {x.name: x_np, y.name: y_np, lr.name: lr_np}
fetch_list = [relu_x, relu_y, relu_lr]
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog, feed=feed, fetch_list=fetch_list, return_merged=True)
relu_x_np_unmerged, relu_y_np_unmerged, relu_lr_np_unmerged = exe.run(
prog, feed=feed, fetch_list=fetch_list, return_merged=False)
assert_merged_unmerged(relu_x_np, relu_x_np_unmerged)
assert_merged_unmerged(relu_y_np, relu_y_np_unmerged)
assert_merged_unmerged(relu_lr_np, relu_lr_np_unmerged)
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
if not lr.persistable or place_num <= len(places):
assert_result(lr_np, relu_lr_np)
else:
expected_relu_lr_np = max(lr_np[0], 0)
self.assertTrue(np.all(expected_relu_lr_np == relu_lr_np))
def feed_list_test():
for place_num in six.moves.range(1, len(places) + 1):
x_np_list = []
y_np_list = []
lr_np_list = []
feed_list = []
for _ in six.moves.range(place_num):
x_np = gen_random([1, self.size])
y_np = gen_random([1, self.size])
lr_np = gen_random([1])
x_np_list.append(x_np)
y_np_list.append(y_np)
lr_np_list.append(lr_np)
feed_list.append({
x.name: x_np,
y.name: y_np,
lr.name: lr_np
})
fetch_list = [relu_x, relu_y, relu_lr]
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog,
feed=feed_list,
fetch_list=fetch_list,
return_merged=True)
relu_x_np_unmerged, relu_y_np_unmerged, relu_lr_np_unmerged = exe.run(
prog,
feed=feed_list,
fetch_list=fetch_list,
return_merged=False)
assert_merged_unmerged(relu_x_np, relu_x_np_unmerged)
assert_merged_unmerged(relu_y_np, relu_y_np_unmerged)
assert_merged_unmerged(relu_lr_np, relu_lr_np_unmerged)
x_np = np.concatenate(x_np_list)
y_np = np.concatenate(y_np_list)
lr_np = np.concatenate(lr_np_list)
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
assert_result(lr_np, relu_lr_np)
for _ in six.moves.range(self.iterations):
if use_split:
feed_split_test()
else:
feed_list_test()
def test_main(self):
places = [fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
places.append(fluid.cuda_places())
for p in places:
for has_persistable in [False, True]:
for use_split in [False, True]:
self.run_network(
p, use_split=use_split, has_persistable=has_persistable)
class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
def setUp(self):
self.epoch_num = 3
self.batch_num = 101 # a prime number
self.batch_size = 32
def create_reader(self):
def __impl__():
for _ in six.moves.range(self.batch_num):
yield np.random.random([self.batch_size, 1]).astype('float32'),
return __impl__
def run_network(self, iterable, use_cuda, drop_last):
x = fluid.data(shape=[None, 1], name='x', dtype='float32')
places = fluid.cuda_places() if use_cuda else fluid.cpu_places(4)
loader = fluid.io.DataLoader.from_generator(
feed_list=[x], capacity=16, iterable=iterable, drop_last=drop_last)
y = fluid.layers.fc(x, size=10)
loss = fluid.layers.reduce_mean(y)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
places=places, loss_name=loss.name)
loader.set_batch_generator(
self.create_reader(), places=places if iterable else None)
for _ in six.moves.range(self.epoch_num):
actual_batch_num = 0
if loader.iterable:
for feed_data in loader():
x_data, = exe.run(prog, feed=feed_data, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] / self.batch_size)
else:
loader.start()
try:
while True:
x_data, = exe.run(prog, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] /
self.batch_size)
except fluid.core.EOFException:
loader.reset()
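# With drop_last=True and more than one device, trailing batches that cannot
# fill every device are dropped, so fewer batches are observed than the
# generator produced.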
if not drop_last or len(places) == 1:
self.assertEqual(self.batch_num, actual_batch_num)
else:
self.assertGreater(self.batch_num, actual_batch_num)
def test_main(self):
use_cuda_list = [False, True] if fluid.is_compiled_with_cuda(
) else [False]
iterable_list = [False, True]
drop_last_list = [False, True]
for iterable in iterable_list:
for use_cuda in use_cuda_list:
for drop_last in drop_last_list:
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.run_network(iterable, use_cuda, drop_last)
if __name__ == '__main__':
unittest.main()