Commit 7160cb0f authored by sneaxiy

decoupled reader

test=develop
Parent fc87ef74
@@ -54,6 +54,7 @@ class ReaderBase {
  private:
   friend class DecoratedReader;
+  friend class MultiDecoratedReader;
   // These methods can be only invoked inside DecoratedReader to record the
   // decorating chain.
   void InsertDecoratedReader(
@@ -62,15 +63,20 @@ class ReaderBase {
   std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
 };
 
-class DecoratedReader : public ReaderBase,
+class DecoratedReaderBase : public ReaderBase {
+ public:
+  virtual void RegisterDecorateChain() = 0;
+};
+
+class DecoratedReader : public DecoratedReaderBase,
                         public std::enable_shared_from_this<DecoratedReader> {
  public:
   explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
-      : ReaderBase(), reader_(reader) {
+      : DecoratedReaderBase(), reader_(reader) {
     PADDLE_ENFORCE_NOT_NULL(reader_);
   }
 
-  void RegisterDecorateChain() {
+  void RegisterDecorateChain() final {
     reader_->InsertDecoratedReader(shared_from_this());
   }
 
@@ -84,6 +90,41 @@ class DecoratedReader : public ReaderBase,
   std::shared_ptr<ReaderBase> reader_;
 };
 
+class MultiDecoratedReader
+    : public DecoratedReaderBase,
+      public std::enable_shared_from_this<MultiDecoratedReader> {
+ public:
+  explicit MultiDecoratedReader(
+      const std::vector<std::shared_ptr<ReaderBase>>& readers)
+      : readers_(readers) {
+    PADDLE_ENFORCE(!readers_.empty());
+    for (auto& r : readers_) {
+      PADDLE_ENFORCE_NOT_NULL(r);
+    }
+  }
+
+  void RegisterDecorateChain() final {
+    for (auto& r : readers_) {
+      r->InsertDecoratedReader(shared_from_this());
+    }
+  }
+
+ protected:
+  void ShutdownImpl() override {
+    for (auto& r : readers_) {
+      r->Shutdown();
+    }
+  }
+
+  void StartImpl() override {
+    for (auto& r : readers_) {
+      r->Start();
+    }
+  }
+
+  std::vector<std::shared_ptr<ReaderBase>> readers_;
+};
+
 // FileReader is just a conceptual class.
 class FileReader : public ReaderBase {};
@@ -132,8 +173,10 @@ class ReaderHolder {
 };
 
 template <typename T, typename... ARGS>
-inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
-  std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
+inline std::shared_ptr<DecoratedReaderBase> MakeDecoratedReader(
+    ARGS&&... args) {
+  std::shared_ptr<DecoratedReaderBase> reader(
+      new T(std::forward<ARGS>(args)...));
   reader->RegisterDecorateChain();
   return reader;
 }
...
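Note on the decorator-chain refactor above: RegisterDecorateChain() is deliberately a separate step after construction because shared_from_this() cannot be called inside a constructor; at that point no shared_ptr owns the object yet, so the call would throw std::bad_weak_ptr. MakeDecoratedReader is the factory that sequences the two steps. A minimal standalone sketch of the same two-phase pattern; Node, Register, and MakeNode are hypothetical names, not part of the Paddle API:

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

class Node : public std::enable_shared_from_this<Node> {
 public:
  // Must NOT call shared_from_this() here: no shared_ptr owns *this yet,
  // so it would throw std::bad_weak_ptr. Hence the separate Register() step.
  explicit Node(std::shared_ptr<Node> parent) : parent_(std::move(parent)) {}

  void Register() {
    if (parent_) parent_->children_.push_back(shared_from_this());
  }

  std::shared_ptr<Node> parent_;
  std::vector<std::weak_ptr<Node>> children_;  // weak, like decorated_readers_
};

// Mirrors MakeDecoratedReader: construct first, then register the chain.
template <typename... Args>
std::shared_ptr<Node> MakeNode(Args&&... args) {
  std::shared_ptr<Node> n(new Node(std::forward<Args>(args)...));
  n->Register();  // safe: *this is now owned by a shared_ptr
  return n;
}

int main() {
  auto root = MakeNode(nullptr);
  auto child = MakeNode(root);
  assert(root->children_.size() == 1);  // chain recorded after construction
}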
@@ -17,7 +17,10 @@ function(reader_library TARGET_NAME)
     PARENT_SCOPE)
 endfunction()
 
+cc_library(py_reader SRCS py_reader.cc DEPS reader)
+cc_library(compose_reader SRCS compose_reader.cc DEPS reader)
+
 cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
 reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
 reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
 reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
@@ -26,7 +29,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
 reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS buffered_reader)
 reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
 reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
-reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
+reader_library(create_py_reader_op SRCS create_py_reader_op.cc DEPS py_reader)
 
 if (NOT WIN32 AND NOT ON_INFER)
   cc_library(ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib)
@@ -38,7 +41,7 @@ cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
 # Export local libraries to parent
 # set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
 
-op_library(read_op)
+op_library(read_op DEPS py_reader compose_reader buffered_reader)
 
 foreach(src ${LOCAL_READER_LIBS})
   set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs")
...
@@ -34,7 +34,7 @@ class BlockingQueue {
   explicit BlockingQueue(size_t capacity, bool speed_test_mode = false)
       : capacity_(capacity), speed_test_mode_(speed_test_mode), closed_(false) {
     PADDLE_ENFORCE_GT(
-        capacity_, 0,
+        capacity_, static_cast<size_t>(0),
         "The capacity of a reader::BlockingQueue must be greater than 0.");
   }
@@ -114,6 +114,11 @@ class BlockingQueue {
     return queue_.size();
   }
 
+  void Clear() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    queue_.clear();
+  }
+
  private:
   size_t capacity_;
   bool speed_test_mode_;
...
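A note on the static_cast<size_t>(0) change above: PADDLE_ENFORCE_GT compares two operands whose types are deduced from the arguments, so checking a size_t member against the int literal 0 produces a signed/unsigned comparison, which compilers flag under -Wsign-compare (and CI builds with -Werror reject). Casting makes both operands size_t. This reading of the macro is an assumption; the hypothetical EnforceGT below just illustrates the pitfall:

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for PADDLE_ENFORCE_GT: both operand types are deduced.
template <typename T1, typename T2>
void EnforceGT(const T1& a, const T2& b) {
  // With T1 = size_t and T2 = int, `a > b` is a signed/unsigned comparison,
  // which -Wsign-compare (often promoted to an error in CI) flags.
  assert(a > b);
}

int main() {
  size_t capacity = 8;
  EnforceGT(capacity, static_cast<size_t>(0));  // both size_t: no warning
}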
@@ -28,8 +28,10 @@ BufferedReader::~BufferedReader() {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place_)) {
     platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
-    PADDLE_ENFORCE(cudaStreamDestroy(stream));
-    for (auto &event : events) PADDLE_ENFORCE(cudaEventDestroy(event));
+    PADDLE_ENFORCE(cudaStreamDestroy(stream_));
+    for (auto &event : events_) {
+      PADDLE_ENFORCE(cudaEventDestroy(event));
+    }
   }
 #endif
 }
@@ -44,14 +46,15 @@ BufferedReader::BufferedReader(
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place_)) {
     platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
-    compute_stream =
+    compute_stream_ =
         ((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance()
                                              .Get(place_)))
             ->stream();
-    events.resize(buffer_size);
-    for (auto &event : events)
+    events_.resize(buffer_size);
+    for (auto &event : events_) {
       PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-    PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+    }
+    PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
   }
 #endif
   cpu_buffer_.resize(buffer_size);
@@ -70,7 +73,7 @@ void BufferedReader::ReadAsync(size_t i) {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place_)) {
     platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
-    PADDLE_ENFORCE(cudaEventRecord(events[i], compute_stream));
+    PADDLE_ENFORCE(cudaEventRecord(events_[i], compute_stream_));
   }
 #endif
   position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
@@ -86,7 +89,7 @@ void BufferedReader::ReadAsync(size_t i) {
     // TensorCopySync would block other stream
     if (platform::is_gpu_place(place_)) {
       platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
-      PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events[i], 0));
+      PADDLE_ENFORCE(cudaStreamWaitEvent(stream_, events_[i], 0));
       TensorVec &gpu = gpu_buffer_[i];
       gpu.resize(cpu.size());
       for (size_t i = 0; i < cpu.size(); ++i) {
@@ -97,23 +100,24 @@ void BufferedReader::ReadAsync(size_t i) {
         auto gpu_ptr = gpu[i].mutable_data(place_, cpu[i].type());
         auto size =
             cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
-        if (platform::is_cuda_pinned_place(cpu_place))
+        if (platform::is_cuda_pinned_place(cpu_place)) {
           memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
                        boost::get<platform::CUDAPinnedPlace>(cpu_place),
-                       cpu_ptr, size, stream);
-        else if ((platform::is_gpu_place(cpu_place)))
+                       cpu_ptr, size, stream_);
+        } else if ((platform::is_gpu_place(cpu_place))) {
           memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
                        boost::get<platform::CUDAPlace>(cpu_place), cpu_ptr,
-                       size, stream);
-        else
+                       size, stream_);
+        } else {
           // if cpu place is not pinned, async copy is slower than sync copy,
           // so we use sync copy instead.
           memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
                        boost::get<platform::CPUPlace>(cpu_place), cpu_ptr, size,
                        0);
+        }
         gpu[i].set_lod(cpu[i].lod());
       }
-      PADDLE_ENFORCE(cudaStreamSynchronize(stream));
+      PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
     }
 #endif
     return i;
...
@@ -63,9 +63,9 @@ class BufferedReader : public framework::DecoratedReader {
   std::vector<TensorVec> gpu_buffer_;
   size_t prev_pos_{-1UL};
 #ifdef PADDLE_WITH_CUDA
-  cudaStream_t stream;
-  cudaStream_t compute_stream;
-  std::vector<cudaEvent_t> events;
+  cudaStream_t stream_;
+  cudaStream_t compute_stream_;
+  std::vector<cudaEvent_t> events_;
 #endif
 };
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/compose_reader.h"
namespace paddle {
namespace operators {
namespace reader {
ComposeReader::ComposeReader(
const std::vector<std::shared_ptr<framework::ReaderBase>> &readers)
: framework::MultiDecoratedReader(readers) {}
void ComposeReader::ReadNext(std::vector<framework::LoDTensor> *out) {
out->clear();
std::vector<framework::LoDTensor> each_ret;
for (auto &r : readers_) {
r->ReadNext(&each_ret);
out->reserve(out->size() + each_ret.size());
for (auto &data : each_ret) {
out->emplace_back(std::move(data));
}
}
}
} // namespace reader
} // namespace operators
} // namespace paddle
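ComposeReader's contract, as implemented above: one ReadNext() call pops one batch from each underlying reader and concatenates the results in reader order, so N underlying readers yield N times the tensor slots of a single reader. A self-contained sketch of that merge semantics, using toy stand-ins (ToyReader, plain strings) rather than the real ReaderBase/LoDTensor types:

#include <cassert>
#include <string>
#include <vector>

// Toy stand-ins for ReaderBase/LoDTensor, just to show the merge order.
using Tensor = std::string;
struct ToyReader {
  std::vector<Tensor> batch;
  void ReadNext(std::vector<Tensor>* out) { *out = batch; }
};

// Mirrors ComposeReader::ReadNext: concatenate every reader's batch.
void ComposeReadNext(std::vector<ToyReader>& readers,
                     std::vector<Tensor>* out) {
  out->clear();
  std::vector<Tensor> each_ret;
  for (auto& r : readers) {
    r.ReadNext(&each_ret);
    out->reserve(out->size() + each_ret.size());
    for (auto& t : each_ret) out->emplace_back(std::move(t));
  }
}

int main() {
  std::vector<ToyReader> readers{{{"img0", "lbl0"}}, {{"img1", "lbl1"}}};
  std::vector<Tensor> out;
  ComposeReadNext(readers, &out);
  // Two readers x two slots each -> four tensors, grouped by reader.
  assert((out == std::vector<Tensor>{"img0", "lbl0", "img1", "lbl1"}));
}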
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/reader.h"
namespace paddle {
namespace operators {
namespace reader {
class ComposeReader : public framework::MultiDecoratedReader {
public:
explicit ComposeReader(
const std::vector<std::shared_ptr<framework::ReaderBase>> &readers);
void ReadNext(std::vector<framework::LoDTensor> *out) override;
};
} // namespace reader
} // namespace operators
} // namespace paddle
@@ -12,37 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
+#include "paddle/fluid/operators/reader/py_reader.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"
 
 namespace paddle {
 namespace operators {
 namespace reader {
 
-class PyReader : public framework::FileReader {
- public:
-  explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
-      : framework::FileReader() {
-    PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
-    queue_ = queue;
-  }
-
-  void ReadNext(std::vector<framework::LoDTensor>* out) override {
-    bool success;
-    *out = queue_->Pop(&success);
-    if (!success) out->clear();
-  }
-
-  ~PyReader() { queue_->Close(); }
-
-  void Shutdown() override { queue_->Close(); }
-
-  void Start() override { queue_->ReOpen(); }
-
- private:
-  std::shared_ptr<LoDTensorBlockingQueue> queue_;
-};
-
 class CreatePyReaderOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/py_reader.h"
namespace paddle {
namespace operators {
namespace reader {
PyReader::PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue;
}
void PyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
bool success;
*out = queue_->Pop(&success);
if (!success) out->clear();
}
PyReader::~PyReader() { queue_->Close(); }
void PyReader::Shutdown() { queue_->Close(); }
void PyReader::Start() { queue_->ReOpen(); }
MultiQueuePyReader::MultiQueuePyReader(
const std::vector<std::shared_ptr<LoDTensorBlockingQueue>>& queues)
: queues_(queues) {
PADDLE_ENFORCE(!queues_.empty());
for (auto& q : queues_) {
PADDLE_ENFORCE_NOT_NULL(q);
}
}
void MultiQueuePyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
auto idx = read_out_idx_.fetch_add(1) % queues_.size();
for (size_t i = 0; i < queues_.size(); ++i) {
*out = queues_[idx]->Pop();
if (!out->empty()) return;
idx = (idx + 1) % queues_.size();
}
}
MultiQueuePyReader::~MultiQueuePyReader() {
for (auto& q : queues_) {
q->Close();
}
}
void MultiQueuePyReader::Shutdown() {
for (auto& q : queues_) {
q->Close();
}
read_out_idx_.store(0, std::memory_order::memory_order_seq_cst);
}
void MultiQueuePyReader::Start() {
for (auto& q : queues_) {
q->ReOpen();
}
}
} // namespace reader
} // namespace operators
} // namespace paddle
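MultiQueuePyReader::ReadNext above round-robins across the queues: an atomic counter picks the starting queue, and after an empty pop it tries each remaining queue once before giving up. A minimal sketch of that index walk; PopFrom is a hypothetical stand-in for the blocking-queue pop:

#include <atomic>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical pop: returns false when queue q has nothing to offer.
static bool PopFrom(const std::vector<bool>& has_data, size_t q) {
  return has_data[q];
}

// Mirrors MultiQueuePyReader::ReadNext: start at fetch_add(1) % n, then
// scan at most n queues, wrapping with (idx + 1) % n.
static bool RoundRobinRead(std::atomic<size_t>* counter,
                           const std::vector<bool>& has_data) {
  size_t n = has_data.size();
  size_t idx = counter->fetch_add(1) % n;
  for (size_t i = 0; i < n; ++i) {
    if (PopFrom(has_data, idx)) return true;  // got a non-empty batch
    idx = (idx + 1) % n;
  }
  return false;  // every queue came up empty/closed
}

int main() {
  std::atomic<size_t> counter{0};
  std::vector<bool> has_data{false, true, false};
  assert(RoundRobinRead(&counter, has_data));    // starts at 0, succeeds at 1
  assert(!RoundRobinRead(&counter, {false, false, false}));
}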
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace paddle {
namespace operators {
namespace reader {
class PyReader : public framework::FileReader {
public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue);
void ReadNext(std::vector<framework::LoDTensor>* out) override;
~PyReader();
void Shutdown() override;
void Start() override;
private:
std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
class MultiQueuePyReader : public framework::FileReader {
public:
explicit MultiQueuePyReader(
const std::vector<std::shared_ptr<LoDTensorBlockingQueue>>& queues);
void ReadNext(std::vector<framework::LoDTensor>* out) override;
~MultiQueuePyReader();
void Shutdown() override;
void Start() override;
private:
std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
std::atomic<size_t> read_out_idx_{0};
};
} // namespace reader
} // namespace operators
} // namespace paddle
@@ -5,7 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune
 if(WITH_PYTHON)
   list(APPEND PYBIND_DEPS py_func_op)
 endif()
-set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc)
+set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc)
 
 if(WITH_PYTHON)
   if(WITH_AMD_GPU)
...
@@ -54,6 +54,7 @@ limitations under the License. */
 #include "paddle/fluid/pybind/ir.h"
 #include "paddle/fluid/pybind/protobuf.h"
 #include "paddle/fluid/pybind/pybind.h"  // NOLINT
+#include "paddle/fluid/pybind/reader_py.h"
 #include "paddle/fluid/pybind/recordio.h"
 #include "paddle/fluid/pybind/tensor_py.h"
@@ -106,6 +107,16 @@ bool IsCompiledWithDIST() {
 #endif
 }
 
+template <typename PlaceType1, typename PlaceType2>
+static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) {
+  return paddle::platform::Place(p1) == paddle::platform::Place(p2);
+}
+
+template <typename PlaceType>
+static inline int PlaceIndex(const PlaceType &p) {
+  return static_cast<int>(paddle::platform::Place(p).which());
+}
+
 PYBIND11_MODULE(core, m) {
   // Not used, just make sure cpu_info.cc is linked.
   paddle::platform::CpuTotalPhysicalMemory();
@@ -452,6 +463,7 @@ PYBIND11_MODULE(core, m) {
 All parameter, weight, gradient are variables in Paddle.
 )DOC")
+      .def(py::init<>())
       .def("is_int", [](const Variable &var) { return var.IsType<int>(); })
       .def("set_int",
            [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
@@ -493,9 +505,7 @@ All parameter, weight, gradient are variables in Paddle.
       },
       py::return_value_policy::reference);
 
-  py::class_<framework::ReaderHolder>(m, "Reader", "")
-      .def("start", &framework::ReaderHolder::Start)
-      .def("reset", &framework::ReaderHolder::ResetAll);
+  BindReader(&m);
 
   using LoDTensorBlockingQueue =
       ::paddle::operators::reader::LoDTensorBlockingQueue;
@@ -657,29 +667,65 @@ All parameter, weight, gradient are variables in Paddle.
         PADDLE_THROW("Cannot use CUDAPlace in CPU only version");
 #endif
       })
+      .def("_type", &PlaceIndex<platform::CUDAPlace>)
+      .def("_equals", &IsSamePlace<platform::CUDAPlace, platform::Place>)
+      .def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CUDAPlace>)
+      .def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CPUPlace>)
+      .def("_equals",
+           &IsSamePlace<platform::CUDAPlace, platform::CUDAPinnedPlace>)
+      .def("gpu_device_id",
+           [](platform::CUDAPlace &self) { return self.device; })
       .def("__str__", string::to_string<const platform::CUDAPlace &>);
 
   py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
       .def(py::init<>())
+      .def("_type", &PlaceIndex<platform::CPUPlace>)
+      .def("_equals", &IsSamePlace<platform::CPUPlace, platform::Place>)
+      .def("_equals", &IsSamePlace<platform::CPUPlace, platform::CUDAPlace>)
+      .def("_equals", &IsSamePlace<platform::CPUPlace, platform::CPUPlace>)
+      .def("_equals",
+           &IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
      .def("__str__", string::to_string<const platform::CPUPlace &>);
 
   py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
       .def("__init__",
-           [](platform::CUDAPinnedPlace &) {
+           [](platform::CUDAPinnedPlace &self) {
 #ifndef PADDLE_WITH_CUDA
              PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
 #endif
+             new (&self) platform::CUDAPinnedPlace();
           })
+      .def("_type", &PlaceIndex<platform::CUDAPinnedPlace>)
+      .def("_equals", &IsSamePlace<platform::CUDAPinnedPlace, platform::Place>)
+      .def("_equals",
+           &IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPlace>)
+      .def("_equals",
+           &IsSamePlace<platform::CUDAPinnedPlace, platform::CPUPlace>)
+      .def("_equals",
+           &IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>)
       .def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
 
   py::class_<platform::Place>(m, "Place")
       .def(py::init<>())
+      .def("_type", &PlaceIndex<platform::Place>)
+      .def("_equals", &IsSamePlace<platform::Place, platform::Place>)
+      .def("_equals", &IsSamePlace<platform::Place, platform::CUDAPlace>)
+      .def("_equals", &IsSamePlace<platform::Place, platform::CPUPlace>)
+      .def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
       .def("is_gpu_place",
           [](platform::Place &self) { return platform::is_gpu_place(self); })
+      .def("is_cpu_place",
+           [](platform::Place &self) { return platform::is_cpu_place(self); })
+      .def("is_cuda_pinned_place",
+           [](platform::Place &self) {
+             return platform::is_cuda_pinned_place(self);
+           })
       .def("gpu_device_id",
           [](platform::Place &self) {
             return boost::get<platform::CUDAPlace>(self).device;
           })
+      .def("set_place", [](platform::Place &self,
+                           const platform::Place &other) { self = other; })
       .def("set_place",
           [](platform::Place &self, const platform::CPUPlace &cpu_place) {
             self = cpu_place;
...
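The new _type/_equals bindings above lean on platform::Place being a boost::variant over the concrete place structs: .which() reports the index of the active alternative, and converting both operands to Place lets a single operator== compare across all four place types. A rough sketch of the same idea using std::variant and toy place types (not Paddle's):

#include <cassert>
#include <variant>

struct CPUPlace {
  bool operator==(const CPUPlace &) const { return true; }
};
struct CUDAPlace {
  int device;
  bool operator==(const CUDAPlace &o) const { return device == o.device; }
};
using Place = std::variant<CPUPlace, CUDAPlace>;

// Mirrors PlaceIndex: the variant alternative's index identifies the type.
template <typename PlaceType>
int PlaceIndex(const PlaceType &p) {
  return static_cast<int>(Place(p).index());
}

// Mirrors IsSamePlace: unify both operands into Place, then compare once.
template <typename P1, typename P2>
bool IsSamePlace(const P1 &a, const P2 &b) {
  return Place(a) == Place(b);
}

int main() {
  assert(PlaceIndex(CPUPlace{}) != PlaceIndex(CUDAPlace{0}));
  assert(IsSamePlace(CUDAPlace{1}, CUDAPlace{1}));
  assert(!IsSamePlace(CUDAPlace{1}, CPUPlace{}));
}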
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/reader_py.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/compose_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
class FeedReader {
using ResultDictList =
std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
public:
FeedReader(std::unique_ptr<framework::ReaderHolder> reader,
const std::vector<std::string> &names, size_t num_places,
bool drop_last = true)
: reader_(std::move(reader)),
names_(names),
num_places_(num_places),
drop_last_(drop_last) {}
ResultDictList ReadNext() {
std::vector<framework::LoDTensor> tensors;
reader_->ReadNext(&tensors);
if (tensors.empty()) return ResultDictList();
PADDLE_ENFORCE(tensors.size() % names_.size() == 0,
"Tensor size: %d, names size: %d", tensors.size(),
names_.size());
size_t read_place_num = tensors.size() / names_.size();
if (drop_last_ && read_place_num != num_places_) {
return ResultDictList();
}
ResultDictList ret(read_place_num);
for (size_t i = 0; i < tensors.size(); ++i) {
ret[i / names_.size()].emplace(names_[i % names_.size()],
std::move(tensors[i]));
}
return ret;
}
void Start() { reader_->Start(); }
void Reset() { reader_->ResetAll(); }
private:
std::unique_ptr<framework::ReaderHolder> reader_;
std::vector<std::string> names_;
size_t num_places_;
bool drop_last_;
};
static std::unique_ptr<framework::ReaderHolder> CreatePyReader(
const std::vector<
std::shared_ptr<operators::reader::LoDTensorBlockingQueue>> &queues,
const std::vector<platform::Place> &dst_places) {
std::shared_ptr<framework::ReaderBase> reader;
if (queues.size() == 1) {
reader.reset(new operators::reader::PyReader(queues[0]));
} else {
reader.reset(new operators::reader::MultiQueuePyReader(queues));
}
std::vector<std::shared_ptr<framework::ReaderBase>> buffered_reader;
buffered_reader.reserve(dst_places.size());
for (auto &p : dst_places) {
buffered_reader.emplace_back(
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
reader, p, 2));
}
reader = framework::MakeDecoratedReader<operators::reader::ComposeReader>(
buffered_reader);
auto *holder = new framework::ReaderHolder();
holder->Reset(reader);
return std::unique_ptr<framework::ReaderHolder>(holder);
}
namespace py = pybind11;
void BindReader(py::module *module) {
auto &m = *module;
namespace reader = ::paddle::operators::reader;
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll);
py::class_<FeedReader>(m, "FeedReader", "")
.def("read_next", &FeedReader::ReadNext,
py::call_guard<py::gil_scoped_release>())
.def("start", &FeedReader::Start,
py::call_guard<py::gil_scoped_release>())
.def("reset", &FeedReader::Reset,
py::call_guard<py::gil_scoped_release>());
m.def("create_py_reader",
[](const std::vector<
std::shared_ptr<operators::reader::LoDTensorBlockingQueue>>
queues,
const std::vector<std::string> &names,
const std::vector<platform::Place> &dst_places, bool drop_last) {
return new FeedReader(CreatePyReader(queues, dst_places), names,
dst_places.size(), drop_last);
},
py::return_value_policy::take_ownership);
}
} // namespace pybind
} // namespace paddle
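FeedReader::ReadNext above regroups the flat tensor list into one feed dict per place: with names ["image", "label"] and two places, four tensors map to ret[0] = {image: t0, label: t1} and ret[1] = {image: t2, label: t3}, via ret[i / names.size()] and names[i % names.size()]. A small sketch of just that index arithmetic, with ints standing in for LoDTensors:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using Dict = std::unordered_map<std::string, int>;  // tensor stand-in: int

// Mirrors FeedReader::ReadNext's regrouping of a flat tensor list into
// one feed dict per place (slot-major within each place).
std::vector<Dict> Regroup(const std::vector<int> &tensors,
                          const std::vector<std::string> &names) {
  assert(tensors.size() % names.size() == 0);
  std::vector<Dict> ret(tensors.size() / names.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    ret[i / names.size()].emplace(names[i % names.size()], tensors[i]);
  }
  return ret;
}

int main() {
  auto ret = Regroup({10, 11, 20, 21}, {"image", "label"});
  assert(ret.size() == 2);  // one dict per place
  assert(ret[0].at("image") == 10 && ret[0].at("label") == 11);
  assert(ret[1].at("image") == 20 && ret[1].at("label") == 21);
}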
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
namespace paddle {
namespace pybind {
void BindReader(pybind11::module *module);
} // namespace pybind
} // namespace paddle
@@ -17,6 +17,7 @@ import os
 import six
 import sys
 from .. import compat as cpt
+from .framework import cuda_places, cpu_places
 from . import core
@@ -78,7 +79,8 @@ class CompiledProgram(object):
                            loss_name=None,
                            build_strategy=None,
                            exec_strategy=None,
-                           share_vars_from=None):
+                           share_vars_from=None,
+                           places=None):
         """Configs the program to run in data parallel way.
 
         Args:
@@ -97,6 +99,12 @@ class CompiledProgram(object):
                 will share variables from `share_vars_from`. `share_vars_from`
                 must be run by the executor before this CompiledProgram so that
                 vars are ready.
+            places(list(CUDAPlace)|list(CPUPlace)|None): If provided, the
+                program is compiled only for the given places. Otherwise, the
+                places are determined by the Executor and controlled by the
+                environment variables FLAGS_selected_gpus or
+                CUDA_VISIBLE_DEVICES when using GPU, or CPU_NUM when using CPU.
 
         Returns:
             self
         """
@@ -110,6 +118,12 @@ class CompiledProgram(object):
             self._exec_strategy = ExecutionStrategy()
         if self._build_strategy is None:
             self._build_strategy = BuildStrategy()
+        if places is not None:
+            if not isinstance(places, (list, tuple)):
+                places = [places]
+            self._places = [_place_obj(p) for p in places]
+        else:
+            self._places = None
         return self
 
     def with_inference_optimize(self, config):
@@ -148,19 +162,16 @@ class CompiledProgram(object):
             self._local_scopes = []
 
         self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
-        if self._exec_strategy.use_cuda:
-            gpus_env = os.getenv("FLAGS_selected_gpus")
-            if gpus_env:
-                gpus = [int(s) for s in gpus_env.split(",")]
-            else:
-                gpus = [
-                    i for i in six.moves.range(core.get_cuda_device_count())
-                ]
-            self._places = [core.CUDAPlace(i) for i in gpus]
+        has_set_place = (self._places is not None)
+        if has_set_place:
+            desire_place = _place_obj(self._place)
+            for p in self._places:
+                assert p._type() == desire_place._type(), \
+                    "Place type not match. You may set the wrong type of places"
         else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
+            places = cuda_places(
+            ) if self._exec_strategy.use_cuda else cpu_places()
+            self._places = [_place_obj(p) for p in places]
         assert self._places, "no place for execution"
 
         if self._exec_strategy.num_threads == 0:
@@ -169,9 +180,7 @@ class CompiledProgram(object):
             # performance. Worth tuning for other models in the future.
             self._exec_strategy.num_threads = len(self._places) * 4
         else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._exec_strategy.num_threads = cpu_num * 2
+            self._exec_strategy.num_threads = len(self._places) * 2
 
         trainers_endpoints = self._program._trainers_endpoints
@@ -217,7 +226,7 @@ class CompiledProgram(object):
         if self._compiled:
             if scope and self._scope != scope:
                 raise ValueError("Cannot compile with different scope")
-            if place and self._place != place:
+            if place and not self._place._equals(place):
                 raise ValueError("Cannot compile with different place")
             return self
         self._compiled = True
...
@@ -554,6 +554,10 @@ class Executor(object):
 
         if feed is None:
             feed = {}
+        elif isinstance(feed, (list, tuple)):
+            assert len(feed) == 1, "Not compiled with data parallel"
+            feed = feed[0]
+
         if not isinstance(feed, dict):
             raise TypeError(
                 "feed requires dict as its Parameter. But you passed in %s" %
...
@@ -26,6 +26,7 @@ import six
 import numpy as np
 import subprocess
+import multiprocessing
 from .. import compat as cpt
 from .proto import framework_pb2
@@ -63,6 +64,9 @@ __all__ = [
     'default_main_program',
     'program_guard',
     'name_scope',
+    'cuda_places',
+    'cpu_places',
+    'cuda_pinned_places',
 ]
 
 EMPTY_VAR_NAME = core.kEmptyVarName()
@@ -87,6 +91,38 @@ def _current_expected_place():
     return _imperative_current_expected_place_
def _cpu_num():
return int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
def cuda_places(device_ids=None):
assert core.is_compiled_with_cuda(), \
"Not compiled with CUDA"
if device_ids is None:
gpus_env = os.getenv("FLAGS_selected_gpus")
if gpus_env:
device_ids = [int(s) for s in gpus_env.split(",")]
else:
device_ids = six.moves.range(core.get_cuda_device_count())
elif not isinstance(device_ids, (list, tuple)):
device_ids = [device_ids]
return [core.CUDAPlace(dev_id) for dev_id in device_ids]
def cpu_places(device_count=None):
if device_count is None:
device_count = _cpu_num()
return [core.CPUPlace()] * device_count
def cuda_pinned_places(device_count=None):
assert core.is_compiled_with_cuda(), \
"Not compiled with CUDA"
if device_count is None:
device_count = _cpu_num()
    return [core.CUDAPinnedPlace()] * device_count
 
 class NameScope(object):
     def __init__(self, name="", parent=None):
         self._children = dict()
...
@@ -26,12 +26,14 @@ from paddle.fluid import layers
 from paddle.fluid.executor import Executor
 from paddle.fluid.evaluator import Evaluator
 from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard
+from . import reader
+from .reader import *
 from . import core
 
 __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
     'load_persistables', 'save_inference_model', 'load_inference_model'
-]
+] + reader.__all__
 
 def is_parameter(var):
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core
import six
import threading
from .framework import Program, Variable, program_guard
from .data_feeder import DataFeeder
__all__ = ['PyReader']
def _convert_places(places):
if not isinstance(places, (list, tuple)):
places = [places]
ret = []
for p in places:
if not isinstance(p, core.Place):
tmp = core.Place()
tmp.set_place(p)
p = tmp
ret.append(p)
return ret
class PyReader(object):
def __init__(self, feed_list, places, capacity, multi_queue=True):
self._tensor_reader = None
self._thread = None
# TODO(zjl): to support drop_last = False
self._drop_last = True
self._feed_list = feed_list
self._var_names = [v.name for v in feed_list]
self._queues = []
self._places = _convert_places(places)
self._queue_capacity = capacity
queue_num = len(self._places) if multi_queue else 1
for _ in six.moves.range(queue_num):
self._queues.append(
core.init_lod_tensor_blocking_queue(core.Variable(),
self._queue_capacity))
self._reader = core.create_py_reader(self._queues, self._var_names,
self._places, self._drop_last)
self._exited = True
def __call__(self):
assert self._tensor_reader is not None, \
"Data source of PyReader has not set yet"
class Iterator(object):
def __init__(self, reader):
self._reader = reader
def __iter__(self):
return self
def next(self):
ret = self._reader._reader.read_next()
if len(ret):
return ret
else:
self._reader._restart_reader()
self._reader._reader.reset()
raise StopIteration
return Iterator(self)
def _restart_reader(self):
if not self._exited:
for q in self._queues:
q.close()
self._thread.join()
def __thread_main__():
queue_num = len(self._queues)
idx = 0
for tensors in self._tensor_reader():
array = core.LoDTensorArray()
for item in tensors:
if not isinstance(item, core.LoDTensor):
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
array.append(item)
if not self._queues[idx].push(array):
break
idx = (idx + 1) % queue_num
for q in self._queues:
q.close()
self._exited = True
self._thread = threading.Thread(target=__thread_main__)
self._thread.daemon = True
self._exited = False
self._thread.start()
def set_numpy_reader(self, reader):
assert self._tensor_reader is None, \
"Cannot reset the data source of PyReader"
with program_guard(Program(), Program()):
feeder = DataFeeder(
feed_list=self._feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(reader, multi_devices=False)
def __tensor_reader_impl__():
for slots in paddle_reader():
yield [slots[var.name] for var in self._feed_list]
self.set_tensor_reader(__tensor_reader_impl__)
def set_tensor_reader(self, reader):
assert self._tensor_reader is None, \
"Cannot reset the data source of PyReader"
self._tensor_reader = reader
self._restart_reader()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import time
import six
import unittest
EPOCH_NUM = 60
BATCH_SIZE = 32
CLASS_NUM = 10
def random_reader():
for i in range(BATCH_SIZE * 40):
image = np.random.random([784])
label = np.random.random_integers(low=0, high=CLASS_NUM - 1)
yield image, label
def simple_fc_net(places, use_legacy_py_reader):
startup_prog = fluid.Program()
main_prog = fluid.Program()
startup_prog.random_seed = 1
main_prog.random_seed = 1
reader = paddle.batch(random_reader, batch_size=BATCH_SIZE)
with fluid.unique_name.guard():
with fluid.program_guard(main_prog, startup_prog):
if not use_legacy_py_reader:
image = fluid.layers.data(
name='image', shape=[784], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
py_reader = fluid.io.PyReader(
feed_list=[image, label],
places=places,
capacity=4,
multi_queue=False)
py_reader.set_numpy_reader(reader)
else:
py_reader = fluid.layers.py_reader(
capacity=4,
shapes=[(-1, 784), (-1, 1)],
dtypes=['float32', 'int64'])
image, label = fluid.layers.read_file(py_reader)
py_reader.decorate_paddle_reader(reader)
hidden = image
for hidden_size in [10, 20, 30]:
hidden = fluid.layers.fc(
hidden,
size=hidden_size,
act='tanh',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
predict_label = fluid.layers.fc(hidden,
size=CLASS_NUM,
act='softmax')
loss = fluid.layers.mean(
fluid.layers.cross_entropy(
input=predict_label, label=label))
optimizer = fluid.optimizer.Adam()
optimizer.minimize(loss)
return startup_prog, main_prog, py_reader, loss
class TestBase(unittest.TestCase):
def run_main(self, use_legacy_py_reader, with_data_parallel, places):
with fluid.scope_guard(fluid.Scope()):
startup_prog, main_prog, py_reader, loss = simple_fc_net(
places, use_legacy_py_reader)
exe = fluid.Executor(place=places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog)
if with_data_parallel:
prog = prog.with_data_parallel(
loss_name=loss.name, places=places)
step = 0
start_t = time.time()
if use_legacy_py_reader:
for _ in six.moves.range(EPOCH_NUM):
py_reader.start()
while True:
try:
L, = exe.run(program=prog, fetch_list=[loss])
step += 1
except fluid.core.EOFException:
py_reader.reset()
break
else:
for _ in six.moves.range(EPOCH_NUM):
for d in py_reader():
'''
assert len(d) == len(places)
for i, item in enumerate(d):
image = item['image']
label = item['label']
assert image.shape() == [BATCH_SIZE, 784]
assert label.shape() == [BATCH_SIZE, 1]
assert image._place()._equals(places[i])
assert label._place()._equals(places[i])
'''
L, = exe.run(program=prog, feed=d, fetch_list=[loss])
step += 1
end_t = time.time()
return {"time": end_t - start_t, "step": step}
def prepare_places(self, with_data_parallel):
places = [[fluid.CPUPlace()], ]
if with_data_parallel:
places.append([fluid.CPUPlace()] * 2)
if fluid.core.is_compiled_with_cuda():
tmp = fluid.cuda_places()
assert len(tmp) > 0, "no gpu detected"
if with_data_parallel:
places.append(tmp)
places.append([tmp[0]])
return places
def test_main(self):
for with_data_parallel in [True, False]:
for p in self.prepare_places(with_data_parallel):
t = []
for use_legacy_py_reader in [False, True]:
ret = self.run_main(
use_legacy_py_reader=use_legacy_py_reader,
with_data_parallel=with_data_parallel,
places=p)
ret['legacy'] = use_legacy_py_reader
ret['data_parallel'] = with_data_parallel
ret['places'] = p
t.append(ret)
print(t)
if __name__ == '__main__':
unittest.main()