未验证 提交 3f8031e2 编写于 作者: J jiaqi 提交者: GitHub

dataset (#17973)

(1) use channel instead of vector/BlockingQueue in Dataset,to keep same with existing implementation, and make code more readable and flexible (dataset single output channel or multi output channel). one previous memory out of limit problem is cause by not release memory after training.
(2) add Record because MultiSlotType costs too much memory (80B),fix memory out of limit problem.
(3) add Channel, Archive in paddle/fluid/framework
(4) change dataset from shared_ptr to unique_ptr in pybind
(5) move create/destroy readers from trainer to dataset
(6) move shuffle from datafeed to dataset. dataset holds memory, datafeed is only for load data and feed data to network.
(7) fix thread num bug of Dataset when filelist size < thread num
(8) support set_queue_num in InMemoryDataset
上级 5d54ed4a
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <valarray>
#include <vector>
#include "paddle/fluid/framework/expect.h"
namespace paddle {
namespace framework {
// not a virtual class
class ArchiveBase {
protected:
ArchiveBase() {}
// Archive is not copyable. But to allow move capture by function objects,
// check it at runtime rather than at compile time.
ArchiveBase(const ArchiveBase&) { LOG(FATAL) << "Not supported"; }
ArchiveBase(ArchiveBase&& other)
: buffer_(other.buffer_),
cursor_(other.cursor_),
finish_(other.finish_),
limit_(other.limit_),
deleter_(std::move(other.deleter_)) {
other.buffer_ = NULL;
other.cursor_ = NULL;
other.finish_ = NULL;
other.limit_ = NULL;
other.deleter_ = nullptr;
}
~ArchiveBase() { FreeBuffer(); }
public:
ArchiveBase& operator=(const ArchiveBase&) {
LOG(FATAL) << "Not supported";
return *this;
}
ArchiveBase& operator=(ArchiveBase&& other) {
if (this != &other) {
FreeBuffer();
buffer_ = other.buffer_;
cursor_ = other.cursor_;
finish_ = other.finish_;
limit_ = other.limit_;
deleter_ = std::move(other.deleter_);
other.buffer_ = NULL;
other.cursor_ = NULL;
other.finish_ = NULL;
other.limit_ = NULL;
other.deleter_ = nullptr;
}
return *this;
}
char* Buffer() { return buffer_; }
void SetReadBuffer(char* buffer, size_t length,
std::function<void(char*)>&& deleter) {
SetBuffer(buffer, length, length, std::move(deleter));
}
void SetWriteBuffer(char* buffer, size_t capacity,
std::function<void(char*)>&& deleter) {
SetBuffer(buffer, 0, capacity, std::move(deleter));
}
void SetBuffer(char* buffer, size_t length, size_t capacity,
std::function<void(char*)>&& deleter) {
CHECK(length <= capacity);
FreeBuffer();
buffer_ = buffer;
cursor_ = buffer_;
finish_ = buffer + length;
limit_ = buffer + capacity;
deleter_ = std::move(deleter);
}
char* Cursor() { return cursor_; }
void SetCursor(char* cursor) {
CHECK(cursor >= buffer_ && cursor <= finish_);
cursor_ = cursor;
}
void AdvanceCursor(size_t offset) {
CHECK(offset <= size_t(finish_ - cursor_));
cursor_ += offset;
}
char* Finish() { return finish_; }
void SetFinish(char* finish) {
CHECK(finish >= cursor_ && finish <= limit_);
finish_ = finish;
}
void AdvanceFinish(size_t offset) {
CHECK(offset <= size_t(limit_ - finish_));
finish_ += offset;
}
char* Limit() { return limit_; }
size_t Position() { return cursor_ - buffer_; }
size_t Length() { return finish_ - buffer_; }
size_t Capacity() { return limit_ - buffer_; }
bool Empty() { return finish_ == buffer_; }
void Reset() {
FreeBuffer();
buffer_ = NULL;
cursor_ = NULL;
finish_ = NULL;
limit_ = NULL;
}
void Clear() {
cursor_ = buffer_;
finish_ = buffer_;
}
char* Release() {
char* buf = buffer_;
buffer_ = NULL;
cursor_ = NULL;
finish_ = NULL;
deleter_ = nullptr;
return buf;
}
void Resize(size_t newsize) {
#ifdef _LINUX
if (unlikely(newsize > Capacity())) {
#else
if (newsize > Capacity()) {
#endif
Reserve(std::max(Capacity() * 2, newsize));
}
finish_ = buffer_ + newsize;
cursor_ = std::min(cursor_, finish_);
}
void Reserve(size_t newcap) {
if (newcap > Capacity()) {
char* newbuf = NULL;
newbuf = new char[newcap];
CHECK(newbuf != nullptr) << "Reserve failed, out of memory";
if (Length() > 0) {
memcpy(newbuf, buffer_, Length());
}
cursor_ = newbuf + (cursor_ - buffer_);
finish_ = newbuf + (finish_ - buffer_);
limit_ = newbuf + newcap;
FreeBuffer();
buffer_ = newbuf;
deleter_ = std::default_delete<char[]>();
}
}
void PrepareRead(size_t size) {
#ifdef _LINUX
if (unlikely(!(size <= size_t(finish_ - cursor_)))) {
#else
if (!(size <= size_t(finish_ - cursor_))) {
#endif
CHECK(size <= size_t(finish_ - cursor_));
}
}
void PrepareWrite(size_t size) {
#ifdef _LINUX
if (unlikely(size > size_t(limit_ - finish_))) {
#else
if (size > size_t(limit_ - finish_)) {
#endif
Reserve(std::max(Capacity() * 2, Length() + size));
}
}
void Read(void* data, size_t size) {
if (size > 0) {
PrepareRead(size);
memcpy(data, cursor_, size);
AdvanceCursor(size);
}
}
void ReadBack(void* data, size_t size) {
if (size > 0) {
CHECK(size <= size_t(finish_ - cursor_));
memcpy(data, finish_ - size, size);
finish_ -= size;
}
}
void Write(const void* data, size_t size) {
if (size > 0) {
PrepareWrite(size);
memcpy(finish_, data, size);
AdvanceFinish(size);
}
}
template <class T>
void GetRaw(T& x) { // NOLINT
PrepareRead(sizeof(T));
memcpy(&x, cursor_, sizeof(T));
AdvanceCursor(sizeof(T));
}
template <class T>
T GetRaw() {
T x;
GetRaw<T>(x);
return x;
}
template <class T>
void PutRaw(const T& x) {
PrepareWrite(sizeof(T));
memcpy(finish_, &x, sizeof(T));
AdvanceFinish(sizeof(T));
}
protected:
char* buffer_ = NULL;
char* cursor_ = NULL;
char* finish_ = NULL;
char* limit_ = NULL;
std::function<void(char*)> deleter_ = nullptr;
void FreeBuffer() {
if (deleter_) {
deleter_(buffer_);
}
deleter_ = nullptr;
}
}; // NOLINT
template <class Type>
class Archive {};
class BinaryArchiveType {};
typedef Archive<BinaryArchiveType> BinaryArchive;
template <>
class Archive<BinaryArchiveType> : public ArchiveBase {
public:
#define ARCHIVE_REPEAT(T) \
BinaryArchive& operator>>(T& x) { \
GetRaw(x); \
return *this; \
} \
BinaryArchive& operator<<(const T& x) { \
PutRaw(x); \
return *this; \
}
ARCHIVE_REPEAT(int16_t)
ARCHIVE_REPEAT(uint16_t)
ARCHIVE_REPEAT(int32_t)
ARCHIVE_REPEAT(uint32_t)
ARCHIVE_REPEAT(int64_t)
ARCHIVE_REPEAT(uint64_t)
ARCHIVE_REPEAT(float)
ARCHIVE_REPEAT(double)
ARCHIVE_REPEAT(signed char)
ARCHIVE_REPEAT(unsigned char)
ARCHIVE_REPEAT(bool)
#undef ARCHIVE_REPEAT
template <class T>
T Get() {
T x;
*this >> x;
return x;
}
};
template <class AR, class T, size_t N>
Archive<AR>& operator<<(Archive<AR>& ar, const T (&p)[N]) {
for (size_t i = 0; i < N; i++) {
ar << p[i];
}
return ar;
}
template <class AR, class T, size_t N>
Archive<AR>& operator>>(Archive<AR>& ar, T (&p)[N]) {
for (size_t i = 0; i < N; i++) {
ar >> p[i];
}
return ar;
}
template <class AR, class T>
Archive<AR>& operator<<(Archive<AR>& ar, const std::vector<T>& p) {
#ifdef _LINUX
ar << (size_t)p.size();
#else
ar << (uint64_t)p.size();
#endif
for (const auto& x : p) {
ar << x;
}
return ar;
}
template <class AR, class T>
Archive<AR>& operator>>(Archive<AR>& ar, std::vector<T>& p) {
#ifdef _LINUX
p.resize(ar.template Get<size_t>());
#else
p.resize(ar.template Get<uint64_t>());
#endif
for (auto& x : p) {
ar >> x;
}
return ar;
}
template <class AR, class T>
Archive<AR>& operator<<(Archive<AR>& ar, const std::valarray<T>& p) {
#ifdef _LINUX
ar << (size_t)p.size();
#else
ar << (uint64_t)p.size();
#endif
for (const auto& x : p) {
ar << x;
}
return ar;
}
template <class AR, class T>
Archive<AR>& operator>>(Archive<AR>& ar, std::valarray<T>& p) {
#ifdef _LINUX
p.resize(ar.template Get<size_t>());
#else
p.resize(ar.template Get<uint64_t>());
#endif
for (auto& x : p) {
ar >> x;
}
return ar;
}
inline BinaryArchive& operator<<(BinaryArchive& ar, const std::string& s) {
#ifdef _LINUX
ar << (size_t)s.length();
#else
ar << (uint64_t)s.length();
#endif
ar.Write(&s[0], s.length());
return ar;
}
inline BinaryArchive& operator>>(BinaryArchive& ar, std::string& s) {
#ifdef _LINUX
size_t len = ar.template Get<size_t>();
#else
size_t len = ar.template Get<uint64_t>();
#endif
ar.PrepareRead(len);
s.assign(ar.Cursor(), len);
ar.AdvanceCursor(len);
return ar;
}
template <class AR, class T1, class T2>
Archive<AR>& operator<<(Archive<AR>& ar, const std::pair<T1, T2>& x) {
return ar << x.first << x.second;
}
template <class AR, class T1, class T2>
Archive<AR>& operator>>(Archive<AR>& ar, std::pair<T1, T2>& x) { // NOLINT
return ar >> x.first >> x.second;
}
#ifdef _LINUX
template <class AR, class... T>
Archive<AR>& SerializeTuple(Archive<AR>& ar, // NOLINT
const std::tuple<T...>& x, // NOLINT
std::integral_constant<size_t, 0> n) { // NOLINT
return ar;
}
#else
template <class AR, class... T>
Archive<AR>& SerializeTuple(Archive<AR>& ar, // NOLINT
const std::tuple<T...>& x, // NOLINT
std::integral_constant<uint64_t, 0> n) { // NOLINT
return ar;
}
#endif
#ifdef _LINUX
template <class AR, class... T, size_t N>
Archive<AR>& serialize_tuple(Archive<AR>& ar, // NOLINT
const std::tuple<T...>& x, // NOLINT
std::integral_constant<size_t, N> n) { // NOLINT
return SerializeTuple(ar, x, std::integral_constant<size_t, N - 1>())
<< std::get<N - 1>(x);
}
#else
template <class AR, class... T, uint64_t N>
Archive<AR>& serialize_tuple(Archive<AR>& ar, // NOLINT
const std::tuple<T...>& x, // NOLINT
std::integral_constant<uint64_t, N> n) { // NOLINT
return SerializeTuple(ar, x, std::integral_constant<uint64_t, N - 1>())
<< std::get<N - 1>(x);
}
#endif
#ifdef _LINUX
template <class AR, class... T>
Archive<AR>& operator<<(Archive<AR>& ar, const std::tuple<T...>& x) {
const size_t size = std::tuple_size<std::tuple<T...>>::value;
return SerializeTuple(ar, x, std::integral_constant<size_t, size>());
}
#else
template <class AR, class... T>
Archive<AR>& operator<<(Archive<AR>& ar, const std::tuple<T...>& x) {
const uint64_t size = std::tuple_size<std::tuple<T...>>::value;
return SerializeTuple(ar, x, std::integral_constant<uint64_t, size>());
}
#endif
#ifdef _LINUX
template <class AR, class... T>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT
std::integral_constant<size_t, 0> n) {
return ar;
}
#else
template <class AR, class... T>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT
std::integral_constant<uint64_t, 0> n) {
return ar;
}
#endif
#ifdef _LINUX
template <class AR, class... T, size_t N>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT
std::integral_constant<size_t, N> n) {
return DeserializeTuple(ar, x, std::integral_constant<size_t, N - 1>()) >>
std::get<N - 1>(x);
}
#else
template <class AR, class... T, uint64_t N>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT
std::integral_constant<uint64_t, N> n) {
return DeserializeTuple(ar, x, std::integral_constant<uint64_t, N - 1>()) >>
std::get<N - 1>(x);
}
#endif
#ifdef _LINUX
template <class AR, class... T>
Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) {
const size_t size = std::tuple_size<std::tuple<T...>>::value;
return DeserializeTuple(ar, x, std::integral_constant<size_t, size>());
}
#else
template <class AR, class... T>
Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) {
const uint64_t size = std::tuple_size<std::tuple<T...>>::value;
return DeserializeTuple(ar, x, std::integral_constant<uint64_t, size>());
}
#endif
#ifdef _LINUX
#define ARCHIVE_REPEAT(MAP_TYPE, RESERVE_STATEMENT) \
template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator<<(Archive<AR>& ar, \
const MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
ar << (size_t)p.size(); \
for (auto it = p.begin(); it != p.end(); ++it) { \
ar << *it; \
} \
return ar; \
} \
template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
size_t size = ar.template get<size_t>(); \
p.clear(); \
RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<std::pair<KEY, VALUE>>()); \
} \
return ar; \
}
#else
#define ARCHIVE_REPEAT(MAP_TYPE, RESERVE_STATEMENT) \
template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator<<(Archive<AR>& ar, \
const MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
ar << (uint64_t)p.size(); \
for (auto it = p.begin(); it != p.end(); ++it) { \
ar << *it; \
} \
return ar; \
} \
template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
size_t size = ar.template get<uint64_t>(); \
p.clear(); \
RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<std::pair<KEY, VALUE>>()); \
} \
return ar; \
}
#endif
ARCHIVE_REPEAT(std::map, )
ARCHIVE_REPEAT(std::multimap, )
ARCHIVE_REPEAT(std::unordered_map, p.reserve(size))
ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size))
#undef ARCHIVE_REPEAT
#ifdef _LINUX
#define ARCHIVE_REPEAT(SET_TYPE, RESERVE_STATEMENT) \
template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator<<(Archive<AR>& ar, const SET_TYPE<KEY, ARGS...>& p) { \
ar << (size_t)p.size(); \
for (auto it = p.begin(); it != p.end(); ++it) { \
ar << *it; \
} \
return ar; \
} \
template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, SET_TYPE<KEY, ARGS...>& p) { \
size_t size = ar.template get<size_t>(); \
p.clear(); \
RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<KEY>()); \
} \
return ar; \
}
#else
#define ARCHIVE_REPEAT(SET_TYPE, RESERVE_STATEMENT) \
template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator<<(Archive<AR>& ar, const SET_TYPE<KEY, ARGS...>& p) { \
ar << (uint64_t)p.size(); \
for (auto it = p.begin(); it != p.end(); ++it) { \
ar << *it; \
} \
return ar; \
} \
template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, SET_TYPE<KEY, ARGS...>& p) { \
size_t size = ar.template get<uint64_t>(); \
p.clear(); \
RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<KEY>()); \
} \
return ar; \
}
#endif
ARCHIVE_REPEAT(std::set, )
ARCHIVE_REPEAT(std::multiset, )
ARCHIVE_REPEAT(std::unordered_set, p.reserve(size))
ARCHIVE_REPEAT(std::unordered_multiset, p.reserve(size))
#undef ARCHIVE_REPEAT
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include <glog/logging.h>
#include <algorithm>
#include <condition_variable> // NOLINT
#include <deque>
#include <limits>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/framework/expect.h"
namespace paddle {
namespace framework {
template <class T>
class ChannelObject {
public:
ChannelObject() {}
// capacity can be zero
explicit ChannelObject(size_t capacity) {
capacity_ = std::min(MaxCapacity(), capacity);
}
void Clear() {
std::unique_lock<std::mutex> lock(mutex_);
data_.clear();
data_.shrink_to_fit();
}
size_t Capacity() {
return capacity_; // atomic
}
void SetCapacity(size_t x) { // capacity can be zero
std::lock_guard<std::mutex> lock(mutex_);
capacity_ = std::min(MaxCapacity(), x);
Notify();
}
size_t BlockSize() {
return block_size_; // atomic
}
void SetBlockSize(size_t x) {
CHECK(x >= 1) << "block size must be >= 1";
std::lock_guard<std::mutex> lock(mutex_);
block_size_ = x;
}
template <class U>
void InheritFrom(const std::shared_ptr<ChannelObject<U>>& other) {
std::lock_guard<std::mutex> lock(mutex_);
capacity_ = other->Capacity();
block_size_ = other->BlockSize();
}
bool Closed() {
return closed_; // atomic
}
// open channel, then data can be write() to channel
void Open() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = false;
Notify();
}
// close channel, then no more data can be write() to channel
void Close() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
Notify();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return data_.size();
}
bool Empty() {
std::lock_guard<std::mutex> lock(mutex_);
return EmptyUnlocked();
}
// blocking operation
bool Get(T& val) { return Read(1, &val) != 0; } // NOLINT
// blocking operation
// returns 0 if the channel is closed and empty
size_t Read(size_t n, T* p) {
if (n == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
size_t finished = Read(n, p, lock);
Notify();
return finished;
}
// blocking operation
bool Put(T&& val) { return WriteMove(1, &val) != 0; }
// blocking operation
bool Put(const T& val) { return Write(1, &val) != 0; }
// blocking operation
// returns value less than n if the channel is closed
size_t Write(size_t n, const T* p) {
if (n == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
size_t finished = Write(n, p, lock);
Notify();
return finished;
}
// WriteMove() will clear original contents of input array
size_t WriteMove(size_t n, T* p) {
if (n == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
size_t finished = WriteMove(n, p, lock);
Notify();
return finished;
}
// read data of block size from channel to vector
size_t Read(std::vector<T>& p) { // NOLINT
p.resize(block_size_);
size_t finished = Read(p.size(), &p[0]);
p.resize(finished);
return finished;
}
size_t ReadAll(std::vector<T>& p) { // NOLINT
p.clear();
size_t finished = 0;
size_t n = 0;
do {
// _block_size may change anytime
n = block_size_;
p.resize(finished + n);
n = Read(n, &p[finished]);
finished += n;
} while (n != 0);
p.resize(finished);
return finished;
}
// write data from vector to channel
size_t Write(const std::vector<T>& p) { return Write(p.size(), &p[0]); }
// write data from vector to channel
size_t Write(std::vector<T>&& p) { return WriteMove(p.size(), &p[0]); }
private:
size_t capacity_ = MaxCapacity();
size_t block_size_ = 1024;
bool closed_ = false;
std::mutex mutex_;
// use deque to store data
std::deque<T> data_;
size_t reading_count_ = 0;
int empty_waiters_ = 0;
int full_waiters_ = 0;
std::condition_variable empty_cond_;
std::condition_variable full_cond_;
static constexpr size_t MaxCapacity() {
return std::numeric_limits<size_t>::max() / 2;
}
void Notify() {
if (empty_waiters_ != 0 && (!EmptyUnlocked() || closed_)) {
empty_cond_.notify_one();
}
if (full_waiters_ != 0 && (!FullUnlocked() || closed_)) {
full_cond_.notify_one();
}
}
bool EmptyUnlocked() { return data_.empty(); }
bool FullUnlocked() { return data_.size() >= capacity_ + reading_count_; }
bool WaitForRead(std::unique_lock<std::mutex>& lock) { // NOLINT
#ifdef _LINUX
while (unlikely(EmptyUnlocked() && !closed_)) {
#else
while (EmptyUnlocked() && !closed_) {
#endif
if (full_waiters_ != 0) {
full_cond_.notify_one();
}
empty_waiters_++;
empty_cond_.wait(lock);
empty_waiters_--;
}
return !EmptyUnlocked();
}
bool WaitForWrite(std::unique_lock<std::mutex>& lock) { // NOLINT
#ifdef _LINUX
while (unlikely(FullUnlocked() && !closed_)) {
#else
while (FullUnlocked() && !closed_) {
#endif
if (empty_waiters_ != 0) {
empty_cond_.notify_one();
}
full_waiters_++;
full_cond_.wait(lock);
full_waiters_--;
}
return !closed_;
}
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock) { // NOLINT
size_t finished = 0;
CHECK(n <= MaxCapacity() - reading_count_);
reading_count_ += n;
while (finished < n && WaitForRead(lock)) {
size_t m = std::min(n - finished, data_.size());
for (size_t i = 0; i < m; i++) {
p[finished++] = std::move(data_.front());
data_.pop_front();
}
reading_count_ -= m;
}
reading_count_ -= n - finished;
return finished;
}
size_t Write(size_t n,
const T* p, // NOLINT
std::unique_lock<std::mutex>& lock) { // NOLINT
size_t finished = 0;
while (finished < n && WaitForWrite(lock)) {
size_t m =
std::min(n - finished, capacity_ + reading_count_ - data_.size());
for (size_t i = 0; i < m; i++) {
data_.push_back(p[finished++]);
}
}
return finished;
}
size_t WriteMove(size_t n,
T* p, // NOLINT
std::unique_lock<std::mutex>& lock) { // NOLINT
size_t finished = 0;
while (finished < n && WaitForWrite(lock)) {
size_t m =
std::min(n - finished, capacity_ + reading_count_ - data_.size());
for (size_t i = 0; i < m; i++) {
data_.push_back(std::move(p[finished++]));
}
}
return finished;
}
}; // NOLINT
template <class T>
using Channel = std::shared_ptr<ChannelObject<T>>;
template <class T>
Channel<T> MakeChannel(size_t capacity = std::numeric_limits<size_t>::max()) {
return std::make_shared<ChannelObject<T>>(capacity);
}
template <class T, class U>
Channel<T> MakeChannel(const Channel<U>& other) {
CHECK(other != nullptr) << "channel can not be NULL";
Channel<T> chan = std::make_shared<ChannelObject<T>>();
chan->InheritFrom(other);
return chan;
}
// NOTE: ChannelReader is a wrapper for quick read channel with a buffer. It
// will read a block data from channel, but user can get data one by one. So it
// is important to notice that user must call operator>> until false, or call
// get_buffer_remain until false to make sure the buffered data all readed.
template <class T>
class ChannelReader {
public:
explicit ChannelReader(ChannelObject<T>* channel = nullptr) {
Reset(channel);
}
~ChannelReader() { CHECK(cursor_ == 0) << "Forgot to read buffer data"; }
ChannelObject<T>* channel() { return channel_; }
void Reset(ChannelObject<T>* channel) {
CHECK(channel != nullptr) << "Channel can not be nullptr";
channel_ = channel;
cursor_ = 0;
failed_ = !channel;
}
// whether there were read failed
operator bool() { return !failed_; }
ChannelReader<T>& operator>>(T& val) {
if (failed_) {
return *this;
}
if (cursor_ >= buffer_.size()) {
cursor_ = 0;
if (channel_->read(buffer_) == 0) {
failed_ = true;
return *this;
}
}
val = std::move(buffer_[cursor_++]);
return *this;
}
bool GetBufferRemain(T& val) { // NOLINT
if (cursor_ >= buffer_.size()) {
cursor_ = 0;
return false;
}
val = std::move(buffer_[cursor_++]);
return true;
}
private:
ChannelObject<T>* channel_ = nullptr;
std::vector<T> buffer_;
size_t cursor_ = 0;
bool failed_ = true;
}; // NOLINT
template <class T>
class ChannelWriter {
public:
explicit ChannelWriter(ChannelObject<T>* channel = nullptr) {
Reset(channel);
}
~ChannelWriter() { CHECK(buffer_.empty()) << "Forgot to flush"; }
ChannelObject<T>* channel() { return channel_; }
void Reset(ChannelObject<T>* channel) {
CHECK(buffer_.empty()) << "Forgot to flush";
CHECK(channel != nullptr) << "Channel can not be nullptr";
channel_ = channel;
buffer_.clear();
failed_ = !channel;
}
// whether there were write failed
operator bool() { return !failed_; }
ChannelWriter<T>& operator<<(T&& val) {
if (failed_) {
return *this;
}
buffer_.push_back(std::move(val));
if (buffer_.size() >= channel_->BlockSize()) {
Flush();
}
return *this;
}
ChannelWriter<T>& operator<<(const T& val) {
if (failed_) {
return *this;
}
buffer_.push_back(val);
if (buffer_.size() >= channel_->BlockSize()) {
Flush();
}
return *this;
}
void Flush() {
if (failed_ || buffer_.empty()) {
buffer_.clear();
return;
}
failed_ |=
channel_->WriteMove(buffer_.size(), &buffer_[0]) != buffer_.size();
buffer_.clear();
}
private:
ChannelObject<T>* channel_ = nullptr;
std::vector<T> buffer_;
bool failed_ = true;
}; // NOLINT
// only used for range-for loop
// for (auto& x : chan) {...}
template <class T>
struct ChannelIterator {
std::shared_ptr<ChannelReader<T>> reader_;
T data_;
void operator++() {
CHECK(reader_ != nullptr) << "reader can not be NULL";
if (!(*reader_ >> data_)) {
reader_ = nullptr;
}
}
T& operator*() { return data_; }
friend bool operator==(const ChannelIterator<T>& a,
const ChannelIterator<T>& b) {
return a.reader_ == b.reader_;
}
friend bool operator!=(const ChannelIterator<T>& a,
const ChannelIterator<T>& b) {
return a.reader_ != b.reader_;
}
}; // NOLINT
template <class T>
ChannelIterator<T> begin(ChannelObject<T>* chan) {
ChannelIterator<T> it{std::make_shared<ChannelReader<T>>(chan), T()};
++it;
return it;
}
template <class T>
ChannelIterator<T> end(ChannelObject<T>* chan) {
return {nullptr, T()};
}
} // namespace framework
} // namespace paddle
此差异已折叠。
......@@ -14,6 +14,11 @@ limitations under the License. */
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include <fstream>
#include <future> // NOLINT
#include <memory>
......@@ -24,7 +29,9 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -88,17 +95,15 @@ class DataFeed {
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {}
virtual void SetInputChannel(void* channel) {}
// This function will do nothing at default
virtual void SetOutputChannel(void* channel) {}
// This function will do nothing at default
virtual void SetMemoryDataMutex(std::mutex* mutex) {}
virtual void SetConsumeChannel(void* channel) {}
// This function will do nothing at default
virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {}
// This function will do nothing at default
virtual void SetFleetSendBatchSize(int64_t size) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
......@@ -106,21 +111,6 @@ class DataFeed {
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle() {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
// This function will do nothing at default
virtual void FillMemoryDataToChannel() {}
// This function will do nothing at default
virtual void FillChannelToMemoryData() {}
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) {}
virtual int64_t GetChannelDataSize() { return 0; }
// This function will do nothing at default
virtual void ReleaseChannelData() {}
protected:
// The following three functions are used to check if it is executed in this
......@@ -212,54 +202,32 @@ class PrivateQueueDataFeed : public DataFeed {
};
template <typename T>
class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
class InMemoryDataFeed : public DataFeed {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetInputChannel(void* channel);
virtual void SetOutputChannel(void* channel);
virtual void SetConsumeChannel(void* channel);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
virtual int64_t GetChannelDataSize();
virtual void ReleaseChannelData();
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
int thread_id_;
int thread_num_;
int trainer_num_;
uint32_t rand_seed;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
int64_t fleet_send_batch_size_;
// sleep after send is to slow down sending data, but it's trick,
// should be removed later.
int64_t fleet_send_sleep_seconds_;
std::ifstream file_;
std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_;
paddle::framework::ChannelObject<T>* output_channel_;
paddle::framework::ChannelObject<T>* consume_channel_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
......@@ -381,6 +349,126 @@ class MultiSlotType {
std::vector<size_t> offset_;
};
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
const MultiSlotType& ins) {
ar << ins.GetType();
#ifdef _LINUX
ar << ins.GetOffset();
#else
const auto& offset = ins.GetOffset();
ar << (uint64_t)offset.size();
for (const size_t& x : offset) {
ar << (const uint64_t)x;
}
#endif
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
return ar;
}
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
MultiSlotType& ins) {
ar >> ins.MutableType();
#ifdef _LINUX
ar >> ins.MutableOffset();
#else
auto& offset = ins.MutableOffset();
offset.resize(ar.template Get<uint64_t>());
for (size_t& x : offset) {
uint64_t t;
ar >> t;
x = (size_t)t;
}
#endif
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
return ar;
}
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
};
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
const FeatureKey& fk) {
ar << fk.uint64_feasign_;
ar << fk.float_feasign_;
return ar;
}
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
FeatureKey& fk) {
ar >> fk.uint64_feasign_;
ar >> fk.float_feasign_;
return ar;
}
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
const FeatureItem& fi) {
ar << fi.sign();
ar << fi.slot();
return ar;
}
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
FeatureItem& fi) {
ar >> fi.sign();
ar >> fi.slot();
return ar;
}
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
const Record& r) {
ar << r.uint64_feasigns_;
ar << r.float_feasigns_;
ar << r.ins_id_;
return ar;
}
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
Record& r) {
ar >> r.uint64_feasigns_;
ar >> r.float_feasigns_;
ar >> r.ins_id_;
return ar;
}
// This DataFeed is used to feed multi-slot type data.
// The format of multi-slot type data:
// [n feasign_0 feasign_1 ... feasign_n]*
......@@ -391,7 +479,6 @@ class MultiSlotDataFeed
virtual ~MultiSlotDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename);
// virtual void ReadThread();
protected:
virtual void ReadThread();
......@@ -403,24 +490,16 @@ class MultiSlotDataFeed
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
class MultiSlotInMemoryDataFeed
: public InMemoryDataFeed<std::vector<MultiSlotType>> {
class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str);
virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str);
virtual bool ParseOneInstance(Record* instance);
virtual bool ParseOneInstanceFromPipe(Record* instance);
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
......@@ -32,9 +32,14 @@ namespace framework {
// constructor
template <typename T>
DatasetImpl<T>::DatasetImpl() {
VLOG(3) << "DatasetImpl<T>::DatasetImpl() constructor";
thread_num_ = 1;
trainer_num_ = 1;
channel_num_ = 1;
file_idx_ = 0;
cur_channel_ = 0;
fleet_send_batch_size_ = 80000;
fleet_send_sleep_seconds_ = 2;
}
// set filelist, file_idx_ will reset to zero.
......@@ -58,10 +63,6 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
// should inform reader of trainer_num directly
for (auto reader : readers_) {
reader->SetTrainerNum(trainer_num);
}
}
// if you run distributed, and want to do global shuffle,
......@@ -70,9 +71,6 @@ void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
template <typename T>
void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
fleet_send_batch_size_ = size;
for (auto reader : readers_) {
reader->SetFleetSendBatchSize(size);
}
}
template <typename T>
......@@ -92,12 +90,38 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
&data_feed_desc_);
}
// readers_.size() may not be equal to thread_num_,
// it changes when filelist_.size() < thread_num_
template <typename T>
std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() {
return readers_;
void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num;
}
template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
std::vector<paddle::framework::DataFeed*> ret;
ret.reserve(readers_.size());
for (auto i : readers_) {
ret.push_back(i.get());
}
return ret;
}
template <typename T>
void DatasetImpl<T>::CreateChannel() {
if (input_channel_ == nullptr) {
input_channel_ = paddle::framework::MakeChannel<T>();
}
if (multi_output_channel_.size() == 0) {
multi_output_channel_.reserve(channel_num_);
for (int i = 0; i < channel_num_; ++i) {
multi_output_channel_.push_back(paddle::framework::MakeChannel<T>());
}
}
if (multi_consume_channel_.size() == 0) {
multi_consume_channel_.reserve(channel_num_);
for (int i = 0; i < channel_num_; ++i) {
multi_consume_channel_.push_back(paddle::framework::MakeChannel<T>());
}
}
}
// if sent message between workers, should first call this function
......@@ -119,9 +143,6 @@ void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(
......@@ -130,20 +151,63 @@ void DatasetImpl<T>::LoadIntoMemory() {
for (std::thread& t : load_threads) {
t.join();
}
input_channel_->Close();
int64_t in_chan_size = input_channel_->Size();
input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << memory_data_.size()
<< ", memory data size=" << input_channel_->Size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::PreLoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() begin";
preload_threads_.clear();
for (int64_t i = 0; i < thread_num_; ++i) {
preload_threads_.push_back(std::thread(
&paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
}
VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() end";
}
template <typename T>
void DatasetImpl<T>::WaitPreLoadDone() {
VLOG(3) << "DatasetImpl<T>::WaitPreLoadDone() begin";
for (std::thread& t : preload_threads_) {
t.join();
}
input_channel_->Close();
int64_t in_chan_size = input_channel_->Size();
input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1);
VLOG(3) << "DatasetImpl<T>::WaitPreLoadDone() end";
}
// release memory data
template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_);
for (int i = 0; i < readers_.size(); ++i) {
readers_[i]->ReleaseChannelData();
if (input_channel_) {
input_channel_->Clear();
input_channel_ = nullptr;
}
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
if (!multi_output_channel_[i]) {
continue;
}
multi_output_channel_[i]->Clear();
multi_output_channel_[i] = nullptr;
}
std::vector<paddle::framework::Channel<T>>().swap(multi_output_channel_);
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
if (!multi_consume_channel_[i]) {
continue;
}
multi_consume_channel_[i]->Clear();
multi_consume_channel_[i] = nullptr;
}
std::vector<paddle::framework::Channel<T>>().swap(multi_consume_channel_);
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
......@@ -153,21 +217,22 @@ void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::LocalShuffle, readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
if (!input_channel_ || input_channel_->Size() == 0) {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, no data to shuffle";
return;
}
std::vector<T>().swap(memory_data_);
auto fleet_ptr = FleetWrapper::GetInstance();
input_channel_->Close();
std::vector<T> data;
input_channel_->ReadAll(data);
std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine());
input_channel_->Open();
input_channel_->Write(std::move(data));
data.clear();
data.shrink_to_fit();
input_channel_->Close();
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
......@@ -178,23 +243,75 @@ void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
auto fleet_ptr = FleetWrapper::GetInstance();
// local shuffle all data before global shuffle
std::shuffle(memory_data_.begin(), memory_data_.end(),
fleet_ptr->LocalRandomEngine());
if (!input_channel_ || input_channel_->Size() == 0) {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, no data to shuffle";
return;
}
// local shuffle
input_channel_->Close();
std::vector<T> data;
input_channel_->ReadAll(data);
std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine());
input_channel_->Open();
input_channel_->Write(std::move(data));
data.clear();
data.shrink_to_fit();
input_channel_->Close();
input_channel_->SetBlockSize(fleet_send_batch_size_);
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() input_channel_ size "
<< input_channel_->Size();
auto global_shuffle_func = [this]() {
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<T> data;
while (this->input_channel_->Read(data)) {
std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_);
for (auto& t : data) {
auto client_id = fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
ars[client_id] << t;
}
std::vector<std::future<int32_t>> total_status;
std::vector<int> send_index(this->trainer_num_);
for (int i = 0; i < this->trainer_num_; ++i) {
send_index[i] = i;
}
std::shuffle(send_index.begin(), send_index.end(),
fleet_ptr->LocalRandomEngine());
for (auto index = 0u; index < this->trainer_num_; ++index) {
int i = send_index[index];
if (ars[i].Length() == 0) {
continue;
}
std::string msg(ars[i].Buffer(), ars[i].Length());
auto ret = fleet_ptr->SendClientToClientMsg(0, i, msg);
total_status.push_back(std::move(ret));
}
for (auto& t : total_status) {
t.wait();
}
ars.clear();
ars.shrink_to_fit();
data.clear();
data.shrink_to_fit();
sleep(this->fleet_send_sleep_seconds_);
}
};
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(
&paddle::framework::DataFeed::GlobalShuffle, readers_[i].get()));
global_shuffle_threads.push_back(std::thread(global_shuffle_func));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}
std::vector<T>().swap(memory_data_);
global_shuffle_threads.clear();
global_shuffle_threads.shrink_to_fit();
input_channel_->Clear();
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
......@@ -203,78 +320,67 @@ void DatasetImpl<T>::GlobalShuffle() {
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should > 0";
int file_cnt = filelist_.size();
int memory_data_size = memory_data_.size();
if (memory_data_size != 0 && thread_num_ > memory_data_size) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", memory data size = " << memory_data_size
<< ". Changing Dataset thread num = " << memory_data_size;
thread_num_ = memory_data_size;
} else if (file_cnt != 0 && thread_num_ > file_cnt) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing Dataset thread num = " << file_cnt;
thread_num_ = file_cnt;
}
VLOG(3) << "thread_num in Readers: " << thread_num_;
VLOG(3) << "thread num in Dataset: " << thread_num_;
VLOG(3) << "Filelist size in Dataset: " << filelist_.size();
VLOG(3) << "channel num in Dataset: " << channel_num_;
CHECK(thread_num_ > 0) << "thread num should > 0";
CHECK(thread_num_ <= filelist_.size())
<< "thread num should <= filelist size";
CHECK(channel_num_ > 0) << "channel num should > 0";
CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num";
VLOG(3) << "readers size: " << readers_.size();
VLOG(3) << "Filelist size in readers: " << filelist_.size();
if (readers_.size() != 0) {
VLOG(3) << "readers_.size() = " << readers_.size()
<< ", will not create again";
return;
}
VLOG(3) << "data feed class name: " << data_feed_desc_.name();
int channel_idx = 0;
for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
readers_.back()->SetMemoryData(&memory_data_);
readers_.back()->SetMemoryDataMutex(&mutex_for_update_memory_data_);
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
readers_.back()->SetFileListIndex(&file_idx_);
readers_.back()->SetFileList(filelist_);
readers_[i]->Init(data_feed_desc_);
readers_[i]->SetThreadId(i);
readers_[i]->SetThreadNum(thread_num_);
readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_);
if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get());
}
if (cur_channel_ == 0 && channel_idx < multi_output_channel_.size()) {
readers_[i]->SetOutputChannel(multi_output_channel_[channel_idx].get());
readers_[i]->SetConsumeChannel(multi_consume_channel_[channel_idx].get());
} else if (channel_idx < multi_output_channel_.size()) {
readers_[i]->SetOutputChannel(multi_consume_channel_[channel_idx].get());
readers_[i]->SetConsumeChannel(multi_output_channel_[channel_idx].get());
}
++channel_idx;
if (channel_idx >= channel_num_) {
channel_idx = 0;
}
}
VLOG(3) << "readers size: " << readers_.size();
}
template <typename T>
void DatasetImpl<T>::DestroyReaders() {
VLOG(3) << "Calling DestroyReaders()";
// clear memory_data_ before fill it
// because if LoadIntoMemory but no Shuffle,
// memory_data_ has empty data which has been std::move to channel
if (memory_data_.size() != 0) {
std::vector<T>().swap(memory_data_);
}
std::vector<std::thread> fill_threads;
for (int i = 0; i < thread_num_; ++i) {
fill_threads.push_back(
std::thread(&paddle::framework::DataFeed::FillChannelToMemoryData,
readers_[i].get()));
}
for (std::thread& t : fill_threads) {
t.join();
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "readers size: " << readers_.size();
// if memory_data_ is empty, which means it's not InMemory mode,
// so the next epoch should read all data again
if (memory_data_.size() == 0) {
file_idx_ = 0;
}
file_idx_ = 0;
cur_channel_ = 1 - cur_channel_;
}
template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() {
return memory_data_.size();
return input_channel_->Size();
}
template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
int64_t sum = 0;
for (int i = 0; i < readers_.size(); ++i) {
sum += readers_[i]->GetChannelDataSize();
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
sum += multi_output_channel_[i]->Size() + multi_consume_channel_[i]->Size();
}
return sum;
}
......@@ -285,16 +391,34 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
#ifdef _LINUX
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length();
if (msg.length() == 0) {
return 0;
}
paddle::framework::BinaryArchive ar;
ar.SetReadBuffer(const_cast<char*>(msg.c_str()), msg.length(), nullptr);
if (ar.Cursor() == ar.Finish()) {
return 0;
}
std::vector<T> data;
while (ar.Cursor() < ar.Finish()) {
data.push_back(ar.Get<T>());
}
CHECK(ar.Cursor() == ar.Finish());
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
multi_output_channel_[index]->Write(std::move(data));
data.clear();
data.shrink_to_fit();
#endif
return 0;
}
// explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
template class DatasetImpl<Record>;
} // end namespace framework
} // end namespace paddle
......@@ -55,6 +55,8 @@ class Dataset {
// set data fedd desc, which contains:
// data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// set channel num
virtual void SetChannelNum(int channel_num) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
......@@ -67,14 +69,21 @@ class Dataset {
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get channel num
virtual int GetChannelNum() = 0;
// get readers, the reader num depend both on thread num
// and filelist size
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0;
virtual std::vector<paddle::framework::DataFeed*> GetReaders() = 0;
// create input channel and output channel
virtual void CreateChannel() = 0;
// register message handler between workers
virtual void RegisterClientToClientMsgHandler() = 0;
// load all data into memory
virtual void LoadIntoMemory() = 0;
// load all data into memory in async mode
virtual void PreLoadIntoMemory() = 0;
// wait async load done
virtual void WaitPreLoadDone() = 0;
// release all memory data
virtual void ReleaseMemory() = 0;
// local shuffle data
......@@ -110,6 +119,7 @@ class DatasetImpl : public Dataset {
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
......@@ -121,11 +131,13 @@ class DatasetImpl : public Dataset {
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual int GetChannelNum() { return channel_num_; }
virtual std::vector<paddle::framework::DataFeed*> GetReaders();
virtual void CreateChannel();
virtual void RegisterClientToClientMsgHandler();
virtual void LoadIntoMemory();
virtual void PreLoadIntoMemory();
virtual void WaitPreLoadDone();
virtual void ReleaseMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
......@@ -138,8 +150,14 @@ class DatasetImpl : public Dataset {
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<T> memory_data_;
std::mutex mutex_for_update_memory_data_;
paddle::framework::Channel<T> input_channel_;
int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in output_channel, else consume_channel
int cur_channel_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_;
......@@ -148,12 +166,13 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_pick_file_;
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
int64_t fleet_send_batch_size_;
int64_t fleet_send_sleep_seconds_;
std::vector<std::thread> preload_threads_;
};
// use std::vector<MultiSlotType> as data type
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual ~MultiSlotDataset() {}
......
......@@ -21,14 +21,14 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unique_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \
std::unique_ptr<Dataset> Creator_##dataset_class() { \
return std::unique_ptr<Dataset>(new dataset_class); \
} \
class __Registerer_##dataset_class { \
public: \
......@@ -50,7 +50,7 @@ std::string DatasetFactory::DatasetTypeList() {
return dataset_types;
}
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
std::unique_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
......
......@@ -23,7 +23,7 @@ namespace framework {
class DatasetFactory {
public:
static std::string DatasetTypeList();
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
static std::unique_ptr<Dataset> CreateDataset(std::string dataset_class);
};
} // namespace framework
} // namespace paddle
......@@ -19,7 +19,7 @@ namespace framework {
void DeviceWorker::SetRootScope(Scope* root_scope) { root_scope_ = root_scope; }
void DeviceWorker::SetDataFeed(const std::shared_ptr<DataFeed>& data_feed) {
void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
device_reader_ = data_feed;
}
......
......@@ -113,7 +113,7 @@ class DeviceWorker {
// will make this zero copy in the future
virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
virtual void SetDataFeed(DataFeed* data_feed);
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
......@@ -121,7 +121,7 @@ class DeviceWorker {
protected:
Scope* root_scope_;
paddle::platform::Place place_;
std::shared_ptr<DataFeed> device_reader_;
DataFeed* device_reader_;
int64_t batch_num_;
FetchConfig fetch_config_;
bool use_cvm_;
......
......@@ -27,8 +27,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
thread_num_ = readers.size();
......@@ -72,7 +71,6 @@ void DistMultiTrainer::Finalize() {
th.join();
}
pull_dense_worker_->Stop();
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#ifdef _LINUX
#ifndef likely
#define likely(x) __builtin_expect((x), 1)
#endif
#endif
#ifdef _LINUX
#ifndef unlikely
#define unlikely(x) __builtin_expect((x), 0)
#endif
#endif
......@@ -26,9 +26,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
......@@ -75,7 +73,6 @@ void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
......
......@@ -28,9 +28,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
SetDataset(dataset);
// get filelist from trainer_desc here
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
......@@ -259,7 +257,6 @@ void PipelineTrainer::Finalize() {
pipeline_scopes_[0]->FindVar(var)->Get<LoDTensor>();
TensorCopySync(thread_tensor, platform::CPUPlace(), root_tensor);
}
dataset_ptr_->DestroyReaders();
root_scope_->DropKids();
}
......
......@@ -74,7 +74,7 @@ class MultiTrainer : public TrainerBase {
protected:
int thread_num_;
std::vector<std::thread> threads_;
std::vector<std::shared_ptr<DataFeed>> readers_;
std::vector<DataFeed*> readers_;
std::vector<std::shared_ptr<DeviceWorker>> workers_;
};
......@@ -136,7 +136,7 @@ class PipelineTrainer : public TrainerBase {
std::vector<std::unique_ptr<SyncFunctor>> sync_functors_;
std::shared_ptr<platform::NCCLContextMap> nccl_ctx_map_;
std::vector<std::shared_ptr<DataFeed>> readers_;
std::vector<DataFeed*> readers_;
void InitFirstScopeQueue(ScopeQueue* scope_queue, int pipeline_id,
const ProgramDesc& main_program);
......
......@@ -42,33 +42,64 @@ namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::Dataset, std::shared_ptr<framework::Dataset>>(*m,
py::class_<framework::Dataset, std::unique_ptr<framework::Dataset>>(*m,
"Dataset")
.def(py::init([](const std::string& name = "MultiSlotDataset") {
return framework::DatasetFactory::CreateDataset(name);
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_filelist", &framework::Dataset::SetFileList,
py::call_guard<py::gil_scoped_release>())
.def("set_thread_num", &framework::Dataset::SetThreadNum,
py::call_guard<py::gil_scoped_release>())
.def("set_trainer_num", &framework::Dataset::SetTrainerNum,
py::call_guard<py::gil_scoped_release>())
.def("set_fleet_send_batch_size",
&framework::Dataset::SetFleetSendBatchSize)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("get_filelist", &framework::Dataset::GetFileList)
.def("get_thread_num", &framework::Dataset::GetThreadNum)
.def("get_trainer_num", &framework::Dataset::GetTrainerNum)
&framework::Dataset::SetFleetSendBatchSize,
py::call_guard<py::gil_scoped_release>())
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig,
py::call_guard<py::gil_scoped_release>())
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc,
py::call_guard<py::gil_scoped_release>())
.def("get_filelist", &framework::Dataset::GetFileList,
py::call_guard<py::gil_scoped_release>())
.def("get_thread_num", &framework::Dataset::GetThreadNum,
py::call_guard<py::gil_scoped_release>())
.def("get_trainer_num", &framework::Dataset::GetTrainerNum,
py::call_guard<py::gil_scoped_release>())
.def("get_fleet_send_batch_size",
&framework::Dataset::GetFleetSendBatchSize)
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig)
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc)
&framework::Dataset::GetFleetSendBatchSize,
py::call_guard<py::gil_scoped_release>())
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig,
py::call_guard<py::gil_scoped_release>())
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc,
py::call_guard<py::gil_scoped_release>())
.def("register_client2client_msg_handler",
&framework::Dataset::RegisterClientToClientMsgHandler)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("release_memory", &framework::Dataset::ReleaseMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle)
.def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize)
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize);
&framework::Dataset::RegisterClientToClientMsgHandler,
py::call_guard<py::gil_scoped_release>())
.def("create_channel", &framework::Dataset::CreateChannel,
py::call_guard<py::gil_scoped_release>())
.def("create_readers", &framework::Dataset::CreateReaders,
py::call_guard<py::gil_scoped_release>())
.def("destroy_readers", &framework::Dataset::DestroyReaders,
py::call_guard<py::gil_scoped_release>())
.def("load_into_memory", &framework::Dataset::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("preload_into_memory", &framework::Dataset::PreLoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("wait_preload_done", &framework::Dataset::WaitPreLoadDone,
py::call_guard<py::gil_scoped_release>())
.def("release_memory", &framework::Dataset::ReleaseMemory,
py::call_guard<py::gil_scoped_release>())
.def("local_shuffle", &framework::Dataset::LocalShuffle,
py::call_guard<py::gil_scoped_release>())
.def("global_shuffle", &framework::Dataset::GlobalShuffle,
py::call_guard<py::gil_scoped_release>())
.def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum,
py::call_guard<py::gil_scoped_release>());
}
} // end namespace pybind
......
......@@ -71,6 +71,7 @@ class DatasetBase(object):
self.proto_desc.pipe_command = "cat"
self.dataset = core.Dataset("MultiSlotDataset")
self.thread_num = 0
self.filelist = []
def set_pipe_command(self, pipe_command):
"""
......@@ -139,6 +140,7 @@ class DatasetBase(object):
filelist(list): file list
"""
self.dataset.set_filelist(filelist)
self.filelist = filelist
def set_use_var(self, var_list):
"""
......@@ -193,7 +195,14 @@ class DatasetBase(object):
Set data_feed_desc before load or shuffle,
user no need to call this function.
"""
if self.thread_num > len(self.filelist):
self.thread_num = len(self.filelist)
self.dataset.set_thread_num(self.thread_num)
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_readers()
def _finish_to_run(self):
self.dataset.destroy_readers()
def desc(self):
"""
......@@ -226,6 +235,57 @@ class InMemoryDataset(DatasetBase):
""" Init. """
super(InMemoryDataset, self).__init__()
self.proto_desc.name = "MultiSlotInMemoryDataFeed"
self.fleet_send_batch_size = 80000
self.queue_num = None
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
user no need to call this function.
"""
if self.thread_num > len(self.filelist):
self.thread_num = len(self.filelist)
self.dataset.set_thread_num(self.thread_num)
if self.queue_num is None:
self.queue_num = self.thread_num
self.dataset.set_queue_num(self.queue_num)
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_channel()
self.dataset.create_readers()
def set_queue_num(self, queue_num):
"""
Set Dataset output queue num, training threads get data from queues
Args:
set_queue_num(int): dataset output queue num
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_queue_num(12)
"""
self.queue_num = queue_num
def set_fleet_send_batch_size(self, fleet_send_batch_size):
"""
Set fleet send batch size, default is 80000
Args:
fleet_send_batch_size(int): fleet send batch size
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_fleet_send_batch_size(800)
"""
self.fleet_send_batch_size = fleet_send_batch_size
def load_into_memory(self):
"""
......@@ -243,6 +303,39 @@ class InMemoryDataset(DatasetBase):
self._prepare_to_run()
self.dataset.load_into_memory()
def preload_into_memory(self):
"""
Load data into memory in async mode
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.preload_into_memory()
dataset.wait_preload_done()
"""
self._prepare_to_run()
self.dataset.preload_into_memory()
def wait_preload_done(self):
"""
Wait preload_into_memory done
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.preload_into_memory()
dataset.wait_preload_done()
"""
self.dataset.wait_preload_done()
def local_shuffle(self):
"""
Local shuffle
......@@ -282,13 +375,12 @@ class InMemoryDataset(DatasetBase):
"""
trainer_num = 1
fleet_send_batch_size = 80000
if fleet is not None:
fleet._role_maker._barrier_worker()
trainer_num = fleet.worker_num()
self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num)
self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
if fleet is not None:
fleet._role_maker._barrier_worker()
self.dataset.global_shuffle()
......
......@@ -889,6 +889,7 @@ class Executor(object):
if dataset == None:
raise RuntimeError("dataset is needed and should be initialized")
dataset._prepare_to_run()
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
......@@ -900,11 +901,11 @@ class Executor(object):
print_period=print_period)
trainer._set_infer(True)
trainer._gen_trainer_desc()
dataset._prepare_to_run()
self._dump_debug_info(program=program, trainer=trainer)
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
dataset._finish_to_run()
return None
def train_from_dataset(self,
......@@ -969,6 +970,7 @@ class Executor(object):
if dataset == None:
raise RuntimeError("dataset is need and should be initialized")
dataset._prepare_to_run()
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
......@@ -979,9 +981,9 @@ class Executor(object):
fetch_info=fetch_info,
print_period=print_period)
trainer._gen_trainer_desc()
dataset._prepare_to_run()
self._dump_debug_info(program=program, trainer=trainer)
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
dataset._finish_to_run()
return None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册