提交 354abe36 编写于 作者: S sneaxiy

sequential reader stage 1, test=develop

上级 77dd0d97
...@@ -64,7 +64,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d ...@@ -64,7 +64,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass set_reader_device_count_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
...@@ -66,6 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -66,6 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPrintGraphPass("graph_viz_pass", "_fused_graph"); AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
AppendMultiDevPass(); AppendMultiDevPass();
AppendSetReaderDeviceCountPass();
AppendMultiGraphOptPasses(); AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass"); AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
...@@ -221,6 +222,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -221,6 +222,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&strategy_); &strategy_);
} }
void AppendSetReaderDeviceCountPass() {
AppendPass("set_reader_device_count_pass");
}
void AppendPrintGraphPass(const std::string &pass_name, void AppendPrintGraphPass(const std::string &pass_name,
const std::string &debug_file_suffix) { const std::string &debug_file_suffix) {
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
...@@ -385,6 +390,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -385,6 +390,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped."; "GPU, skipped.";
continue; continue;
} }
} else if (pass->Type() == "set_reader_device_count_pass") {
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
} }
VLOG(1) << "Start Apply Pass " << pass->Type(); VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph); graph = pass->Apply(graph);
...@@ -421,6 +428,7 @@ USE_PASS(fuse_sgd_op_pass); ...@@ -421,6 +428,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(set_reader_device_count_pass);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass); USE_PASS(mkldnn_placement_pass);
#endif #endif
......
...@@ -11,6 +11,7 @@ endif() ...@@ -11,6 +11,7 @@ endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(set_reader_device_count_pass SRCS set_reader_device_count_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle) cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SetReaderDeviceCountPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override;
private:
int GetDeviceCount() const;
std::unordered_set<std::string> ReaderOpSet() const;
};
int SetReaderDeviceCountPass::GetDeviceCount() const {
return static_cast<int>(
Get<const std::vector<platform::Place>>(details::kPlaces).size());
}
std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
return {"create_py_reader"};
}
void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
auto dev_cnt = GetDeviceCount();
auto reader_ops = ReaderOpSet();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
node->Wrapper<details::OpHandleBase>());
auto *op_desc = node->Op();
auto &op_base_attrs =
const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
int dev_idx = static_cast<int>(op_handle.GetScopeIdx());
op_desc->SetAttr("device_index", dev_idx);
op_desc->SetAttr("device_count", dev_cnt);
op_base_attrs["device_index"] = dev_idx;
op_base_attrs["device_count"] = dev_cnt;
++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
}
}
VLOG(10) << "Found op number " << found_op_num;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(set_reader_device_count_pass,
paddle::framework::ir::SetReaderDeviceCountPass)
.RequirePassAttr(paddle::framework::details::kPlaces);
...@@ -56,6 +56,7 @@ class CudnnRNNCache; ...@@ -56,6 +56,7 @@ class CudnnRNNCache;
namespace reader { namespace reader {
class LoDTensorBlockingQueueHolder; class LoDTensorBlockingQueueHolder;
class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
...@@ -139,6 +140,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< ...@@ -139,6 +140,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable, Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *, LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
operators::reader::LoDTensorBlockingQueueHolder, operators::reader::LoDTensorBlockingQueueHolder,
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#ifndef _WIN32 #ifndef _WIN32
ncclUniqueId, platform::Communicator, platform::NCCLCommunicator, ncclUniqueId, platform::Communicator, platform::NCCLCommunicator,
......
...@@ -38,8 +38,21 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -38,8 +38,21 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_holder_var, queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found", "No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name); queue_name);
auto* queue_holder = std::shared_ptr<LoDTensorBlockingQueue> queue;
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> ordered_queue;
if (queue_holder_var->IsType<LoDTensorBlockingQueueHolder>()) {
queue = queue_holder_var->Get<LoDTensorBlockingQueueHolder>().GetQueue();
} else if (queue_holder_var
->IsType<OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
auto* queue_holder =
queue_holder_var
->GetMutable<OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
auto dev_cnt = Attr<int>("device_count");
auto dev_idx = static_cast<size_t>(Attr<int>("device_index"));
ordered_queue = queue_holder->GetQueue();
ordered_queue->InitOnce(dev_cnt);
queue = ordered_queue->GetQueue(dev_idx);
}
/* Coverting shape_concat and ranks into DDim of each data. /* Coverting shape_concat and ranks into DDim of each data.
shape_concat and ranks are shapes and shape ranks of each data.E.g. shape_concat and ranks are shapes and shape ranks of each data.E.g.
...@@ -71,8 +84,20 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -71,8 +84,20 @@ class CreatePyReaderOp : public framework::OperatorBase {
for (size_t i = 0; i < need_check_feed_int.size(); ++i) { for (size_t i = 0; i < need_check_feed_int.size(); ++i) {
need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i])); need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i]));
} }
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue(), dims, auto py_reader =
var_types, need_check_feed)); std::make_shared<PyReader>(queue, dims, var_types, need_check_feed);
if (ordered_queue) {
ordered_queue->AddResetMethod([py_reader] {
auto end_readers = py_reader->GetEndPoints();
for (auto* reader : end_readers) {
reader->Shutdown();
}
for (auto* reader : end_readers) {
reader->Start();
}
});
}
out->Reset(py_reader);
} }
}; };
...@@ -82,6 +107,12 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase { ...@@ -82,6 +107,12 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
AddInput("blocking_queue", AddInput("blocking_queue",
"Name of the `LoDTensorBlockingQueueHolder` variable"); "Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("device_index", "The device index this reader offers data")
.SetDefault(0);
AddAttr<int>("device_count",
"The total number of devices the reader offers data")
.SetDefault(1);
AddComment(R"DOC( AddComment(R"DOC(
Create PyReader to support LoDTensor data feeding in Python side. Create PyReader to support LoDTensor data feeding in Python side.
)DOC"); )DOC");
......
...@@ -27,16 +27,11 @@ namespace paddle { ...@@ -27,16 +27,11 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class LoDTensorBlockingQueueHolder;
class LoDTensorBlockingQueue { class LoDTensorBlockingQueue {
friend class LoDTensorBlockingQueueHolder; public:
private:
explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false) explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
: queue_(capacity, speed_test_mode) {} : queue_(capacity, speed_test_mode) {}
public:
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) { bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return queue_.Send(lod_tensor_vec); return queue_.Send(lod_tensor_vec);
} }
...@@ -67,10 +62,145 @@ class LoDTensorBlockingQueue { ...@@ -67,10 +62,145 @@ class LoDTensorBlockingQueue {
inline void Kill() { queue_.Kill(); } inline void Kill() { queue_.Kill(); }
inline bool WaitForInited() { return true; }
private: private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_; BlockingQueue<std::vector<framework::LoDTensor>> queue_;
}; };
class OrderedMultiDeviceLoDTensorBlockingQueue {
public:
OrderedMultiDeviceLoDTensorBlockingQueue(size_t capacity,
bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode) {}
inline bool WaitForInited() {
std::unique_lock<std::mutex> lock(init_mutex_);
cv_.wait(lock, [this] { return queues_ != nullptr || is_closing_; });
is_closing_ = false;
return queues_ != nullptr;
}
inline void InitOnce(size_t dev_cnt) {
PADDLE_ENFORCE_GE(dev_cnt, 1, platform::errors::InvalidArgument(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"));
VLOG(3) << "Ordered queue init start";
{
std::lock_guard<std::mutex> lock(init_mutex_);
if (queues_) {
PADDLE_ENFORCE_EQ(queues_->size(), dev_cnt,
platform::errors::InvalidArgument(
"Device count to init queue must be equal"));
} else {
queues_.reset(
new std::vector<std::shared_ptr<LoDTensorBlockingQueue>>(dev_cnt));
for (auto& item : *queues_) {
auto cap = (capacity_ + dev_cnt - 1) / dev_cnt;
item.reset(new LoDTensorBlockingQueue(cap, speed_test_mode_));
}
}
}
VLOG(3) << "Ordered queue init finish";
cv_.notify_all();
}
const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue(size_t idx) const {
std::lock_guard<std::mutex> lock(init_mutex_);
PADDLE_ENFORCE_NOT_NULL(queues_,
platform::errors::NotFound(
"Queues must be inited first before getting"));
PADDLE_ENFORCE_LT(
idx, queues_->size(),
platform::errors::OutOfRange("The queue index is out of range"));
return (*queues_)[idx];
}
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
return CurQueue()->Push(lod_tensor_vec);
}
bool Push(std::vector<framework::LoDTensor>&& lod_tensor_vec) {
return CurQueue()->Push(std::move(lod_tensor_vec));
}
inline size_t Cap() const { return capacity_; }
inline size_t Size() const {
size_t size = 0;
if (queues_) {
for (auto& item : *queues_) {
size += item->Size();
}
}
return size;
}
inline void ReOpen() {
if (queues_) {
for (auto& item : *queues_) {
item->ReOpen();
}
}
data_index_ = 0;
}
inline void Close() {
{
std::lock_guard<std::mutex> lock(init_mutex_);
if (queues_ == nullptr) {
is_closing_ = true;
}
}
cv_.notify_all();
if (queues_) {
for (auto& item : *queues_) {
item->Close();
}
}
}
inline void Kill() {
if (queues_) {
for (auto& item : *queues_) {
item->Kill();
}
}
}
inline void Reset() {
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
for (auto& method : reset_methods_) {
method();
}
data_index_ = 0;
}
inline void AddResetMethod(const std::function<void()>& reset_method) {
std::lock_guard<std::mutex> reset_lock(reset_mutex_);
reset_methods_.emplace_back(reset_method);
}
private:
const std::shared_ptr<LoDTensorBlockingQueue>& CurQueue() {
return (*queues_)[data_index_.fetch_add(1) % queues_->size()];
}
private:
std::unique_ptr<std::vector<std::shared_ptr<LoDTensorBlockingQueue>>> queues_;
mutable std::atomic<uint64_t> data_index_{0};
const size_t capacity_;
const bool speed_test_mode_;
std::vector<std::function<void()>> reset_methods_;
mutable std::mutex reset_mutex_;
bool is_closing_{false};
mutable std::mutex init_mutex_;
mutable std::condition_variable cv_;
};
class LoDTensorBlockingQueueHolder { class LoDTensorBlockingQueueHolder {
public: public:
void InitOnce(size_t capacity, bool speed_test_mode = false) { void InitOnce(size_t capacity, bool speed_test_mode = false) {
...@@ -88,6 +218,26 @@ class LoDTensorBlockingQueueHolder { ...@@ -88,6 +218,26 @@ class LoDTensorBlockingQueueHolder {
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
}; };
class OrderedMultiDeviceLoDTensorBlockingQueueHolder {
public:
void InitOnce(size_t capacity, bool speed_test_mode = false) {
PADDLE_ENFORCE_EQ(queue_, nullptr,
platform::errors::AlreadyExists(
"OrderedMultiDeviceLoDTensorBlockingQueueHolder::"
"InitOnce() can only be called once"));
queue_.reset(new OrderedMultiDeviceLoDTensorBlockingQueue(capacity,
speed_test_mode));
}
inline const std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue>&
GetQueue() const {
return queue_;
}
private:
std::shared_ptr<OrderedMultiDeviceLoDTensorBlockingQueue> queue_;
};
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -48,7 +48,6 @@ limitations under the License. */ ...@@ -48,7 +48,6 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
...@@ -91,9 +90,6 @@ limitations under the License. */ ...@@ -91,9 +90,6 @@ limitations under the License. */
#include "pybind11/stl.h" #include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
DECLARE_bool(use_ngraph); DECLARE_bool(use_ngraph);
...@@ -942,35 +938,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -942,35 +938,6 @@ All parameter, weight, gradient are variables in Paddle.
BindReader(&m); BindReader(&m);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
using LoDTensorBlockingQueueHolder =
::paddle::operators::reader::LoDTensorBlockingQueueHolder;
py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
pybind11::gil_scoped_release release;
return self.Push(lod_tensor_vec);
})
.def("size", &LoDTensorBlockingQueue::Size)
.def("capacity", &LoDTensorBlockingQueue::Cap)
.def("close", &LoDTensorBlockingQueue::Close)
.def("kill", &LoDTensorBlockingQueue::Kill)
.def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
[](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
VLOG(1) << "init_lod_tensor_blocking_queue";
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
py::return_value_policy::copy);
py::class_<Scope>(m, "_Scope", R"DOC( py::class_<Scope>(m, "_Scope", R"DOC(
Scope is an association of a name to Variable. All variables belong to Scope. Scope is an association of a name to Variable. All variables belong to Scope.
......
...@@ -20,20 +20,40 @@ ...@@ -20,20 +20,40 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "Python.h" #include "Python.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
namespace py = pybind11; namespace py = pybind11;
namespace reader = operators::reader;
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
return queue;
}
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
&queue,
size_t idx) {
return queue->GetQueue(idx);
}
template <typename QueueType>
class MultiDeviceFeedReader { class MultiDeviceFeedReader {
public: public:
using ResultDictList = using ResultDictList =
...@@ -41,7 +61,7 @@ class MultiDeviceFeedReader { ...@@ -41,7 +61,7 @@ class MultiDeviceFeedReader {
using ResultList = std::vector<std::vector<framework::LoDTensor>>; using ResultList = std::vector<std::vector<framework::LoDTensor>>;
MultiDeviceFeedReader( MultiDeviceFeedReader(
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue, const std::shared_ptr<QueueType> &queue,
const std::vector<std::string> &names, const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes, const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes, const std::vector<framework::proto::VarType::Type> &dtypes,
...@@ -54,12 +74,25 @@ class MultiDeviceFeedReader { ...@@ -54,12 +74,25 @@ class MultiDeviceFeedReader {
for (auto &shape : shapes) { for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape)); dims.push_back(framework::make_ddim(shape));
} }
std::shared_ptr<framework::ReaderBase> reader(
new operators::reader::PyReader(queue, dims, dtypes, need_check_feed)); auto first_reader = std::make_shared<reader::PyReader>(
GetQueue(queue, 0), dims, dtypes, need_check_feed);
auto create_or_get_reader = [&](size_t idx) {
if (idx == 0 ||
std::is_same<QueueType, reader::LoDTensorBlockingQueue>::value) {
return first_reader;
} else {
return std::make_shared<reader::PyReader>(GetQueue(queue, idx), dims,
dtypes, need_check_feed);
}
};
readers_.reserve(dst_places.size()); readers_.reserve(dst_places.size());
for (auto &p : dst_places) { for (size_t i = 0; i < dst_places.size(); ++i) {
auto &p = dst_places[i];
auto *holder = new framework::ReaderHolder(); auto *holder = new framework::ReaderHolder();
auto reader = create_or_get_reader(i);
if (use_double_buffer) { if (use_double_buffer) {
holder->Reset( holder->Reset(
framework::MakeDecoratedReader<operators::reader::BufferedReader>( framework::MakeDecoratedReader<operators::reader::BufferedReader>(
...@@ -183,7 +216,7 @@ class MultiDeviceFeedReader { ...@@ -183,7 +216,7 @@ class MultiDeviceFeedReader {
PADDLE_ENFORCE_EQ(status, Status::kSuccess); PADDLE_ENFORCE_EQ(status, Status::kSuccess);
} }
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_; std::shared_ptr<QueueType> queue_;
std::vector<std::string> names_; std::vector<std::string> names_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
...@@ -195,22 +228,18 @@ class MultiDeviceFeedReader { ...@@ -195,22 +228,18 @@ class MultiDeviceFeedReader {
std::vector<std::vector<framework::LoDTensor>> ret_; std::vector<std::vector<framework::LoDTensor>> ret_;
}; };
void BindReader(py::module *module) { template <typename QueueType>
void BindMultiDeviceReader(py::module *module, const char *reader_name) {
auto &m = *module; auto &m = *module;
namespace reader = ::paddle::operators::reader; using ReaderType = MultiDeviceFeedReader<QueueType>;
py::class_<ReaderType>(m, reader_name, "")
py::class_<framework::ReaderHolder>(m, "Reader", "") .def("read_next", &ReaderType::ReadNext,
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
.def("read_next", &MultiDeviceFeedReader::ReadNext,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("read_next_list", &MultiDeviceFeedReader::ReadNextList, .def("read_next_list", &ReaderType::ReadNextList,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("read_next_var_list", .def("read_next_var_list",
[](MultiDeviceFeedReader &self) { [](ReaderType &self) {
auto result_list = self.ReadNextList(); auto result_list = self.ReadNextList();
auto &tensor_list = result_list[0]; auto &tensor_list = result_list[0];
std::vector<std::shared_ptr<imperative::VarBase>> var_list; std::vector<std::shared_ptr<imperative::VarBase>> var_list;
...@@ -234,23 +263,105 @@ void BindReader(py::module *module) { ...@@ -234,23 +263,105 @@ void BindReader(py::module *module) {
return var_list; return var_list;
}, },
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("reset", &MultiDeviceFeedReader::Reset, .def("reset", &ReaderType::Reset,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
}
void BindReader(py::module *module) {
auto &m = *module;
m.def("init_lod_tensor_blocking_queue",
[](framework::Variable &var, size_t capacity,
bool is_ordered) -> py::object {
VLOG(1) << "init_lod_tensor_blocking_queue";
if (is_ordered) {
auto *holder = var.GetMutable<
reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return py::cast(holder->GetQueue());
} else {
auto *holder =
var.GetMutable<reader::LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return py::cast(holder->GetQueue());
}
},
py::return_value_policy::copy);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<reader::LoDTensorBlockingQueue,
std::shared_ptr<reader::LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](reader::LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
return self.Push(lod_tensor_vec);
},
py::call_guard<py::gil_scoped_release>())
.def("size", &reader::LoDTensorBlockingQueue::Size)
.def("capacity", &reader::LoDTensorBlockingQueue::Cap)
.def("close", &reader::LoDTensorBlockingQueue::Close)
.def("kill", &reader::LoDTensorBlockingQueue::Kill)
.def("wait_for_inited", &reader::LoDTensorBlockingQueue::WaitForInited,
py::call_guard<py::gil_scoped_release>());
py::class_<reader::OrderedMultiDeviceLoDTensorBlockingQueue,
std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>>(
m, "OrderedMultiDeviceLoDTensorBlockingQueue", "")
.def("push",
[](reader::OrderedMultiDeviceLoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
return self.Push(lod_tensor_vec);
},
py::call_guard<py::gil_scoped_release>())
.def("size", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Size)
.def("capacity", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Cap)
.def("close", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Close)
.def("kill", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Kill)
.def("wait_for_inited",
&reader::OrderedMultiDeviceLoDTensorBlockingQueue::WaitForInited,
py::call_guard<py::gil_scoped_release>())
.def("reset", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Reset);
BindMultiDeviceReader<reader::LoDTensorBlockingQueue>(
module, "MultiDeviceFeedReader");
BindMultiDeviceReader<reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
module, "OrderedMultiDeviceFeedReader");
m.def("create_py_reader", m.def("create_py_reader",
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> [](const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue,
&queue,
const std::vector<std::string> &names, const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes, const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes, const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed, const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, const std::vector<platform::Place> &dst_places,
bool use_double_buffer) { bool use_double_buffer) {
return new MultiDeviceFeedReader(queue, names, shapes, dtypes, return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
need_check_feed, dst_places, queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer); use_double_buffer);
}, },
py::return_value_policy::take_ownership); py::return_value_policy::take_ownership);
m.def(
"create_py_reader",
[](const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
&queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer) {
queue->InitOnce(dst_places.size());
return new MultiDeviceFeedReader<
reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer);
},
py::return_value_policy::take_ownership);
} }
} // namespace pybind } // namespace pybind
......
...@@ -429,7 +429,7 @@ def _py_reader(capacity, ...@@ -429,7 +429,7 @@ def _py_reader(capacity,
double_buffer_name = "_".join([name, "double_buffer"]) double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, False)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
......
...@@ -87,7 +87,8 @@ class DataLoader(object): ...@@ -87,7 +87,8 @@ class DataLoader(object):
use_double_buffer=True, use_double_buffer=True,
iterable=True, iterable=True,
return_list=False, return_list=False,
use_multiprocess=False): use_multiprocess=False,
keep_order=False):
""" """
Create a DataLoader object for loading data from Python generator. Create a DataLoader object for loading data from Python generator.
Data would be prefetched using Python thread and be pushed Data would be prefetched using Python thread and be pushed
...@@ -133,6 +134,15 @@ class DataLoader(object): ...@@ -133,6 +134,15 @@ class DataLoader(object):
can be used in the dygraph mode. In the static graph mode, can be used in the dygraph mode. In the static graph mode,
whether this parameter is set or not has no effect. whether this parameter is set or not has no effect.
The Default value is False. The Default value is False.
keep_order (bool): whether to assign the data to CPU cores or GPU
cards in order. Supposing that there are 2 batches and we use
2 GPU cards to run the network. If keep_order=True, GPU 0 would
get batch 0 and GPU 1 would get batch 1 exactly. If
keep_order=False, GPU 0 may get batch 0 or may get batch 1, and
GPU 1 may get the rest of the data, which is uncertain. If
keep_order=True, the framework may do some synchronization to
keep the reading order, which may be slower. The default value
is False.
Returns: Returns:
loader (DataLoader): the created DataLoader object. loader (DataLoader): the created DataLoader object.
...@@ -271,12 +281,15 @@ class DataLoader(object): ...@@ -271,12 +281,15 @@ class DataLoader(object):
assert relu.shape == [BATCH_SIZE, 784] assert relu.shape == [BATCH_SIZE, 784]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
# Dygraph only support multiprocess training when using multi GPUs.
# So in each process, we only use 1 GPU card to train the network,
# so `keep_order` would also be True.
return DygraphGeneratorLoader(feed_list, capacity, return DygraphGeneratorLoader(feed_list, capacity,
use_double_buffer, iterable, use_double_buffer, iterable,
return_list, use_multiprocess) return_list, use_multiprocess)
else: else:
return GeneratorLoader(feed_list, capacity, use_double_buffer, return GeneratorLoader(feed_list, capacity, use_double_buffer,
iterable, return_list) iterable, return_list, keep_order)
@staticmethod @staticmethod
def from_dataset(dataset, places, drop_last=True): def from_dataset(dataset, places, drop_last=True):
...@@ -334,6 +347,7 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -334,6 +347,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._batch_reader = None self._batch_reader = None
self._places = None self._places = None
self._feed_list = feed_list self._feed_list = feed_list
self._keep_order = True
if not capacity: if not capacity:
raise ValueError("Please give value to capacity.") raise ValueError("Please give value to capacity.")
...@@ -406,7 +420,7 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -406,7 +420,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._dtypes = [] self._dtypes = []
self._need_check_feed = [] self._need_check_feed = []
self._blocking_queue = core.init_lod_tensor_blocking_queue( self._blocking_queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._capacity) core.Variable(), self._capacity, self._keep_order)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes, self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer) self._need_check_feed, self._places, self._use_double_buffer)
...@@ -614,7 +628,8 @@ class GeneratorLoader(DataLoaderBase): ...@@ -614,7 +628,8 @@ class GeneratorLoader(DataLoaderBase):
capacity=None, capacity=None,
use_double_buffer=True, use_double_buffer=True,
iterable=True, iterable=True,
return_list=False): return_list=False,
keep_order=False):
self._tensor_reader = None self._tensor_reader = None
self._places = None self._places = None
self._thread = None self._thread = None
...@@ -628,6 +643,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -628,6 +643,7 @@ class GeneratorLoader(DataLoaderBase):
raise Exception("Feed list must be given under static mode.") raise Exception("Feed list must be given under static mode.")
self._use_double_buffer = use_double_buffer self._use_double_buffer = use_double_buffer
self._capacity = capacity self._capacity = capacity
self._keep_order = keep_order
if not self._iterable: if not self._iterable:
self._init_non_iterable() self._init_non_iterable()
...@@ -647,8 +663,8 @@ class GeneratorLoader(DataLoaderBase): ...@@ -647,8 +663,8 @@ class GeneratorLoader(DataLoaderBase):
self._need_check_feed = [ self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list v.desc.need_check_feed() for v in self._feed_list
] ]
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(), self._queue = core.init_lod_tensor_blocking_queue(
self._capacity) core.Variable(), self._capacity, self._keep_order)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes, self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer) self._need_check_feed, self._places, self._use_double_buffer)
...@@ -675,16 +691,21 @@ class GeneratorLoader(DataLoaderBase): ...@@ -675,16 +691,21 @@ class GeneratorLoader(DataLoaderBase):
double_buffer_name = data_loader_unique_name_generator('double_buffer') double_buffer_name = data_loader_unique_name_generator('double_buffer')
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity) self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity,
self._keep_order)
if self._keep_order:
block = default_main_program().current_block()
else:
block = default_startup_program().current_block()
startup_blk = default_startup_program().current_block() reader_var = block.create_var(name=reader_name)
startup_var = startup_blk.create_var(name=reader_name)
dtype_int = [int(t) for t in dtypes] dtype_int = [int(t) for t in dtypes]
startup_blk.append_op( block.append_op(
type='create_py_reader', type='create_py_reader',
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]}, outputs={'Out': [reader_var]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
...@@ -693,16 +714,23 @@ class GeneratorLoader(DataLoaderBase): ...@@ -693,16 +714,23 @@ class GeneratorLoader(DataLoaderBase):
'ranks': ranks 'ranks': ranks
}) })
startup_var.desc.set_dtypes(dtypes) reader_var.desc.set_dtypes(dtypes)
startup_var.persistable = True reader_var.persistable = True
reader_var.stop_gradient = True
main_prog_var = _copy_reader_var_( if self._keep_order:
default_main_program().current_block(), startup_var) main_prog_var = reader_var
reader = main_prog_var
reader.reset = self._queue.reset
else:
main_prog_var = _copy_reader_var_(
default_main_program().current_block(), reader_var)
main_prog_var.stop_gradient = True main_prog_var.stop_gradient = True
main_prog_var.persistable = True main_prog_var.persistable = True
reader = monkey_patch_reader_methods(main_prog_var)
reader = monkey_patch_reader_methods(main_prog_var)
if self._use_double_buffer: if self._use_double_buffer:
double_buffer_reader = double_buffer( double_buffer_reader = double_buffer(
reader, name=double_buffer_name) reader, name=double_buffer_name)
...@@ -765,14 +793,19 @@ class GeneratorLoader(DataLoaderBase): ...@@ -765,14 +793,19 @@ class GeneratorLoader(DataLoaderBase):
" to locate the data causes this issue.\n\t* Please consider using " " to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")) "'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
return arr
def _start(self): def _start(self):
def __thread_main__(): def __thread_main__():
try: try:
if not self._queue.wait_for_inited():
return
for tensors in self._tensor_reader(): for tensors in self._tensor_reader():
array = core.LoDTensorArray() array = core.LoDTensorArray()
for item in tensors: for item in tensors:
if not isinstance(item, core.LoDTensor): if not isinstance(item, core.LoDTensor):
self._check_input_array(item) item = self._check_input_array(item)
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace()) tmp.set(item, core.CPUPlace())
item = tmp item = tmp
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import os
import six
def create_reader(shape, batch_number):
def __impl__():
idx = 0
for _ in six.moves.range(batch_number):
yield np.ones(shape).astype('float32') * idx,
idx += 1
return __impl__
class DataLoaderKeepOrderTestBase(unittest.TestCase):
def initParameters(self):
self.iterable = False
self.break_num = 10000
def setUp(self):
self.epoch_num = 3
self.batch_num = 40
self.shape = [3, 4, 5]
self.initParameters()
def build_network(self, places):
input_data = fluid.data(shape=self.shape, dtype='float32', name="input")
loader = fluid.io.DataLoader.from_generator(
capacity=16,
feed_list=[input_data],
keep_order=True,
iterable=self.iterable)
fc = fluid.layers.fc(input_data, size=10)
loss = fluid.layers.reduce_mean(fc)
loader.set_batch_generator(
create_reader(self.shape, self.batch_num),
places=places if loader.iterable else None)
return input_data, loss, loader
def assertInputData(self, batch_id, input_data, dev_cnt):
if isinstance(input_data, list):
self.assertTrue(len(input_data), dev_cnt)
start_val = dev_cnt * batch_id
for each_input_dict in input_data:
input_tensor = np.array(each_input_dict["input"])
self.assertEqual(self.shape, list(input_tensor.shape))
self.assertTrue((input_tensor == start_val).all())
start_val += 1
else:
self.assertEqual(
list(input_data.shape),
[self.shape[0] * dev_cnt] + self.shape[1:])
start_val = dev_cnt * batch_id
for idx in six.moves.range(dev_cnt):
data_part = input_data[idx * self.shape[0]:(idx + 1) *
self.shape[0], :]
self.assertTrue((data_part == start_val).all())
start_val += 1
def get_places(self):
place_list = [fluid.cpu_places(1), fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
place_list.extend([fluid.cuda_places(0), fluid.cuda_places([0, 1])])
return place_list
def test_main(self):
for p in self.get_places():
use_compiled_program_list = [True] if len(p) > 1 else [False, True]
for use_compiled_program in use_compiled_program_list:
self.run_main_with_place(p, use_compiled_program)
def run_main_with_place(self, places, use_compiled_program=True):
with fluid.scope_guard(fluid.Scope()):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data, loss, loader = self.build_network(places)
fetch_list = [input_data]
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
dev_cnt = len(places)
if dev_cnt > 1:
self.assertTrue(use_compiled_program)
main_program = fluid.default_main_program()
if use_compiled_program:
main_program = fluid.CompiledProgram(
main_program).with_data_parallel(
loss_name=loss.name, places=places)
max_batch_num = min(self.break_num,
int(self.batch_num / dev_cnt))
if loader.iterable:
early_break = False
for epoch_id in six.moves.range(self.epoch_num):
early_break = False
batch_id = 0
for data in loader():
if batch_id >= self.break_num:
early_break = True
break
self.assertInputData(batch_id, data, dev_cnt)
fetch_val, = exe.run(program=main_program,
feed=data,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val, dev_cnt)
batch_id += 1
self.assertEqual(batch_id, max_batch_num)
if early_break:
loader._reset()
else:
for epoch_id in six.moves.range(self.epoch_num):
batch_id = 0
loader.start()
try:
while True:
if batch_id >= self.break_num:
loader.reset()
break
fetch_val, = exe.run(program=main_program,
fetch_list=fetch_list)
self.assertInputData(batch_id, fetch_val,
dev_cnt)
batch_id += 1
except fluid.core.EOFException:
loader.reset()
self.assertEqual(batch_id, max_batch_num)
class IterableDataLoaderKeepOrderTest2(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 10000
class IterableDataLoaderKeepOrderTest3(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 2
class IterableDataLoaderKeepOrderTest4(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 2
class IterableDataLoaderKeepOrderTest5(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = False
self.break_num = 0
class IterableDataLoaderKeepOrderTest6(DataLoaderKeepOrderTestBase):
def initParameters(self):
self.iterable = True
self.break_num = 0
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册