Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
58f896c3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
58f896c3
编写于
10月 18, 2016
作者:
Y
Yu Yang
提交者:
qingqing01
10月 18, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speed up PyDP2, support numpy.float array (#207)
上级
45280a07
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
59 addition
and
12 deletion
+59
-12
cmake/flags.cmake
cmake/flags.cmake
+3
-1
demo/mnist/data/get_mnist_data.sh
demo/mnist/data/get_mnist_data.sh
+1
-1
paddle/gserver/dataproviders/DataProvider.cpp
paddle/gserver/dataproviders/DataProvider.cpp
+6
-2
paddle/gserver/dataproviders/DataProvider.h
paddle/gserver/dataproviders/DataProvider.h
+3
-2
paddle/gserver/dataproviders/PyDataProvider2.cpp
paddle/gserver/dataproviders/PyDataProvider2.cpp
+30
-6
paddle/utils/Queue.h
paddle/utils/Queue.h
+15
-0
python/paddle/trainer_config_helpers/data_sources.py
python/paddle/trainer_config_helpers/data_sources.py
+1
-0
未找到文件。
cmake/flags.cmake
浏览文件 @
58f896c3
...
...
@@ -64,7 +64,9 @@ set(COMMON_FLAGS
-Wdelete-non-virtual-dtor
-Wno-unused-parameter
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs
)
-Wno-error=unused-local-typedefs
-Wno-error=unused-function
# Warnings in Numpy Header.
)
foreach
(
flag
${
COMMON_FLAGS
}
)
safe_set_cflag
(
CMAKE_C_FLAGS
${
flag
}
)
...
...
demo/mnist/data/get_mnist_data.sh
浏览文件 @
58f896c3
#!/usr/bin/env sh
# This scripts downloads the mnist data and unzips it.
set
-e
DIR
=
"
$(
cd
"
$(
dirname
"
$0
"
)
"
;
pwd
-P
)
"
rm
-rf
"
$DIR
/raw_data"
mkdir
"
$DIR
/raw_data"
...
...
paddle/gserver/dataproviders/DataProvider.cpp
浏览文件 @
58f896c3
...
...
@@ -57,7 +57,8 @@ void BufferBatch::clone(DataBatch* srcBatch, bool useGpu) {
}
}
DoubleBuffer
::
DoubleBuffer
(
DataProvider
*
dataPool
,
bool
useGpu
,
DoubleBuffer
::
DoubleBuffer
(
DataProvider
*
dataPool
,
bool
useGpu
,
int64_t
batchSize
)
{
batchSize_
=
batchSize
;
dataPool_
=
dataPool
;
...
...
@@ -110,6 +111,9 @@ void DoubleBuffer::removeOneBatch(DataBatch* dataBatch) {
}
void
DoubleBuffer
::
insertOneBatch
(
DataBatch
*
batch
)
{
while
(
!
bufferQueue_
->
waitNotEmptyFor
(
2
/* seconds */
))
{
// time out
if
(
stopping_
)
return
;
}
BufferBatch
*
bufBatch
=
bufferQueue_
->
dequeue
();
// clone and copy the data from an Threadlocal Variable
bufBatch
->
clone
(
batch
,
useGpu_
);
...
...
@@ -138,7 +142,7 @@ void DoubleBuffer::asyncLoadBatch() {
actualSize
=
dataPool_
->
getNextBatchInternal
(
batchSize_
,
&
newBatch
);
}
insertOneBatch
(
&
newBatch
);
}
while
(
actualSize
>
0
);
}
while
(
actualSize
>
0
&&
!
stopping_
);
}
}
...
...
paddle/gserver/dataproviders/DataProvider.h
浏览文件 @
58f896c3
...
...
@@ -259,7 +259,9 @@ typedef Queue<BufferBatch*> BufferBatchQueue;
class
DoubleBuffer
{
public:
DoubleBuffer
(
DataProvider
*
dataPool
,
bool
useGpu
,
int64_t
batchSize
=
0
);
DoubleBuffer
(
DataProvider
*
dataPool
,
bool
useGpu
,
int64_t
batchSize
=
0
);
virtual
~
DoubleBuffer
();
void
removeOneBatch
(
DataBatch
*
dataBatch
);
...
...
@@ -349,7 +351,6 @@ public:
*/
virtual
void
reset
()
{
if
(
doubleBuffer_
!=
nullptr
)
{
LOG
(
INFO
)
<<
"the double-buffer is starting ..."
;
doubleBuffer_
->
startAsyncLoad
();
}
}
...
...
paddle/gserver/dataproviders/PyDataProvider2.cpp
浏览文件 @
58f896c3
...
...
@@ -18,9 +18,16 @@ limitations under the License. */
#include <stdlib.h>
#include <unordered_set>
#include <list>
#include <Python.h>
#include <numpy/numpyconfig.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>
#include "DataProvider.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
...
...
@@ -202,7 +209,10 @@ public:
PyDataProvider2
(
const
DataConfig
&
config
,
const
ModelConfig
&
modelConfig
,
bool
useGpu
)
:
DataProvider
(
config
,
useGpu
),
callingContextCreated_
(
2
)
{
:
DataProvider
(
config
,
useGpu
),
callingContextCreated_
(
2
)
{
if
(
PyArray_API
==
NULL
)
import_array
();
auto
&
args
=
config
.
load_data_args
();
PyObjectPtr
kwargs
=
PyObjectPtr
(
PyDict_New
());
if
(
!
args
.
empty
())
{
...
...
@@ -454,6 +464,7 @@ private:
std
::
condition_variable
pushCV_
;
std
::
condition_variable
pullCV_
;
std
::
mutex
mtx_
;
ThreadBarrier
callingContextCreated_
;
std
::
unique_ptr
<
IPyDataProviderCache
>
cache_
;
...
...
@@ -496,8 +507,8 @@ public:
* Resetting the PyDataProvider. May start reading thread here.
*/
virtual
void
reset
()
{
DataProvider
::
reset
();
resetImpl
(
true
);
DataProvider
::
reset
();
}
/**
...
...
@@ -518,6 +529,7 @@ public:
* Loading a batch of data.
*/
int64_t
getNextBatchInternal
(
int64_t
size_
,
DataBatch
*
batch
)
{
REGISTER_TIMER
(
"PyDP2.getNextBatchInternal"
)
CHECK_GE
(
size_
,
0
);
size_t
size
=
(
size_t
)
size_
;
if
(
loadThread_
)
{
// loading from thread should wait for data pool ready.
...
...
@@ -698,10 +710,22 @@ public:
*/
virtual
void
fill
(
Argument
&
argument
,
PyObject
*
obj
)
{
real
*
dat
=
argument
.
value
->
getData
()
+
height_
*
headerPtr_
->
dim
;
py
::
SequenceHelper
s
(
obj
);
// TODO(yuyang18): Here we can use AVX or SSE to accelerate memory copy.
for
(
size_t
i
=
0
;
i
<
headerPtr_
->
dim
;
++
i
)
{
dat
[
i
]
=
(
real
)
s
.
getDouble
(
i
);
if
(
PyArray_Check
(
obj
))
{
auto
dtype
=
PyArray_DTYPE
((
PyArrayObject
*
)
obj
);
if
(
dtype
->
type
==
'f'
&&
dtype
->
elsize
==
sizeof
(
real
))
{
real
*
data
=
(
real
*
)
PyArray_DATA
((
PyArrayObject
*
)
obj
);
auto
sz
=
PyArray_SIZE
((
PyArrayObject
*
)
obj
);
std
::
copy
(
data
,
data
+
sz
,
dat
);
}
else
{
LOG
(
FATAL
)
<<
"You should yield float"
<<
sizeof
(
real
)
*
8
<<
" array"
;
}
}
else
{
py
::
SequenceHelper
s
(
obj
);
// TODO(yuyang18): Here we can use AVX or SSE to accelerate memory copy.
for
(
size_t
i
=
0
;
i
<
headerPtr_
->
dim
;
++
i
)
{
dat
[
i
]
=
(
real
)
s
.
getDouble
(
i
);
}
}
++
height_
;
}
...
...
paddle/utils/Queue.h
浏览文件 @
58f896c3
...
...
@@ -135,6 +135,21 @@ public:
queueCV_
.
wait
(
lock
,
[
this
]()
{
return
numElements_
==
0
;
});
}
/**
* @brief wait queue is not empty at most for some seconds.
* @param seconds wait time limit.
* @return true if queue is not empty. false if timeout.
*/
bool
waitNotEmptyFor
(
int
seconds
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
queueLock_
);
return
queueCV_
.
wait_for
(
lock
,
std
::
chrono
::
seconds
(
seconds
),
[
this
]
{
return
numElements_
!=
0
;
});
}
private:
std
::
deque
<
T
>
elements_
;
int
numElements_
;
...
...
python/paddle/trainer_config_helpers/data_sources.py
浏览文件 @
58f896c3
...
...
@@ -84,6 +84,7 @@ def define_py_data_source(file_list, cls, module,
data
.
load_data_module
=
load_data_module
data
.
load_data_object
=
load_data_object
data
.
load_data_args
=
load_data_args
data
.
async_load_data
=
True
return
data
data_cls
=
py_data2
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录