提交 58f896c3 编写于 作者: Y Yu Yang 提交者: qingqing01

Speed up PyDP2, support numpy.float array (#207)

上级 45280a07
......@@ -64,7 +64,9 @@ set(COMMON_FLAGS
-Wdelete-non-virtual-dtor
-Wno-unused-parameter
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs)
-Wno-error=unused-local-typedefs
-Wno-error=unused-function # Warnings in Numpy Header.
)
foreach(flag ${COMMON_FLAGS})
safe_set_cflag(CMAKE_C_FLAGS ${flag})
......
#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.
set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
rm -rf "$DIR/raw_data"
mkdir "$DIR/raw_data"
......
......@@ -57,7 +57,8 @@ void BufferBatch::clone(DataBatch* srcBatch, bool useGpu) {
}
}
DoubleBuffer::DoubleBuffer(DataProvider* dataPool, bool useGpu,
DoubleBuffer::DoubleBuffer(DataProvider *dataPool,
bool useGpu,
int64_t batchSize) {
batchSize_ = batchSize;
dataPool_ = dataPool;
......@@ -110,6 +111,9 @@ void DoubleBuffer::removeOneBatch(DataBatch* dataBatch) {
}
void DoubleBuffer::insertOneBatch(DataBatch* batch) {
while (!bufferQueue_->waitNotEmptyFor(2 /* seconds */)) { // time out
if (stopping_) return;
}
BufferBatch* bufBatch = bufferQueue_->dequeue();
// clone and copy the data from an Threadlocal Variable
bufBatch->clone(batch, useGpu_);
......@@ -138,7 +142,7 @@ void DoubleBuffer::asyncLoadBatch() {
actualSize = dataPool_->getNextBatchInternal(batchSize_, &newBatch);
}
insertOneBatch(&newBatch);
} while (actualSize > 0);
} while (actualSize > 0 && !stopping_);
}
}
......
......@@ -259,7 +259,9 @@ typedef Queue<BufferBatch*> BufferBatchQueue;
class DoubleBuffer {
public:
DoubleBuffer(DataProvider* dataPool, bool useGpu, int64_t batchSize = 0);
DoubleBuffer(DataProvider* dataPool,
bool useGpu,
int64_t batchSize = 0);
virtual ~DoubleBuffer();
void removeOneBatch(DataBatch* dataBatch);
......@@ -349,7 +351,6 @@ public:
*/
virtual void reset() {
if (doubleBuffer_ != nullptr) {
LOG(INFO) << "the double-buffer is starting ...";
doubleBuffer_->startAsyncLoad();
}
}
......
......@@ -18,9 +18,16 @@ limitations under the License. */
#include <stdlib.h>
#include <unordered_set>
#include <list>
#include <Python.h>
#include <numpy/numpyconfig.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>
#include "DataProvider.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/Stat.h"
namespace paddle {
......@@ -202,7 +209,10 @@ public:
PyDataProvider2(const DataConfig& config,
const ModelConfig& modelConfig,
bool useGpu)
:DataProvider(config, useGpu), callingContextCreated_(2) {
:DataProvider(config, useGpu),
callingContextCreated_(2) {
if (PyArray_API == NULL)
import_array();
auto& args = config.load_data_args();
PyObjectPtr kwargs = PyObjectPtr(PyDict_New());
if (!args.empty()) {
......@@ -454,6 +464,7 @@ private:
std::condition_variable pushCV_;
std::condition_variable pullCV_;
std::mutex mtx_;
ThreadBarrier callingContextCreated_;
std::unique_ptr<IPyDataProviderCache> cache_;
......@@ -496,8 +507,8 @@ public:
* Resetting the PyDataProvider. May start reading thread here.
*/
virtual void reset() {
DataProvider::reset();
resetImpl(true);
DataProvider::reset();
}
/**
......@@ -518,6 +529,7 @@ public:
* Loading a batch of data.
*/
int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) {
REGISTER_TIMER("PyDP2.getNextBatchInternal")
CHECK_GE(size_, 0);
size_t size = (size_t) size_;
if (loadThread_) { // loading from thread should wait for data pool ready.
......@@ -698,11 +710,23 @@ public:
*/
virtual void fill(Argument &argument, PyObject *obj) {
real* dat = argument.value->getData() + height_ * headerPtr_->dim;
if (PyArray_Check(obj)) {
auto dtype = PyArray_DTYPE((PyArrayObject*)obj);
if (dtype->type == 'f' && dtype->elsize == sizeof(real)) {
real * data = (real*)PyArray_DATA((PyArrayObject*)obj);
auto sz = PyArray_SIZE((PyArrayObject*)obj);
std::copy(data, data + sz, dat);
} else {
LOG(FATAL) << "You should yield float" << sizeof(real) * 8
<< " array";
}
} else {
py::SequenceHelper s(obj);
// TODO(yuyang18): Here we can use AVX or SSE to accelerate memory copy.
for (size_t i=0; i < headerPtr_->dim; ++i) {
dat[i] = (real) s.getDouble(i);
}
}
++height_;
}
......
......@@ -135,6 +135,21 @@ public:
queueCV_.wait(lock, [this]() { return numElements_ == 0; });
}
/**
* @brief wait queue is not empty at most for some seconds.
* @param seconds wait time limit.
* @return true if queue is not empty. false if timeout.
*/
bool waitNotEmptyFor(int seconds) {
std::unique_lock<std::mutex> lock(queueLock_);
return queueCV_.wait_for(
lock,
std::chrono::seconds(seconds),
[this] {
return numElements_ != 0;
});
}
private:
std::deque<T> elements_;
int numElements_;
......
......@@ -84,6 +84,7 @@ def define_py_data_source(file_list, cls, module,
data.load_data_module = load_data_module
data.load_data_object = load_data_object
data.load_data_args = load_data_args
data.async_load_data = True
return data
data_cls = py_data2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册