Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fc9ad34e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fc9ad34e
编写于
3月 01, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/inferencer' into feature/recommendation_v2_api
上级
f7a06f17
9ba231d3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
109 addition
and
18 deletion
+109
-18
demo/mnist/api_train_v2.py
demo/mnist/api_train_v2.py
+13
-0
python/paddle/v2/__init__.py
python/paddle/v2/__init__.py
+5
-1
python/paddle/v2/dataset/mnist.py
python/paddle/v2/dataset/mnist.py
+15
-14
python/paddle/v2/inferencer.py
python/paddle/v2/inferencer.py
+59
-0
python/paddle/v2/reader/decorator.py
python/paddle/v2/reader/decorator.py
+17
-3
未找到文件。
demo/mnist/api_train_v2.py
浏览文件 @
fc9ad34e
...
...
@@ -44,6 +44,19 @@ def main():
batch_size
=
32
),
event_handler
=
event_handler
)
# output is a softmax layer. It returns probabilities.
# Shape should be (100, 10)
probs
=
paddle
.
infer
(
output
=
inference
,
parameters
=
parameters
,
reader
=
paddle
.
reader
.
batched
(
paddle
.
reader
.
limited
(
paddle
.
reader
.
map_readers
(
lambda
item
:
(
item
[
0
],
),
paddle
.
dataset
.
mnist
.
test
()),
limit
=
100
),
batch_size
=
32
))
print
probs
.
shape
if
__name__
==
'__main__'
:
main
()
python/paddle/v2/__init__.py
浏览文件 @
fc9ad34e
...
...
@@ -24,12 +24,13 @@ from . import dataset
from
.
import
reader
import
attr
import
pooling
import
inferencer
import
py_paddle.swig_paddle
as
api
__all__
=
[
'optimizer'
,
'layer'
,
'activation'
,
'parameters'
,
'init'
,
'trainer'
,
'event'
,
'data_type'
,
'attr'
,
'pooling'
,
'data_feeder'
,
'dataset'
,
'reader'
,
'topology'
'topology'
,
'inferencer'
,
'infer'
]
...
...
@@ -39,3 +40,6 @@ def init(**kwargs):
args
.
append
(
'--%s=%s'
%
(
key
,
str
(
kwargs
[
key
])))
api
.
initPaddle
(
*
args
)
infer
=
inferencer
.
infer
python/paddle/v2/dataset/mnist.py
浏览文件 @
fc9ad34e
...
...
@@ -35,24 +35,25 @@ def reader_creator(image_filename, label_filename, buffer_size):
l
=
subprocess
.
Popen
([
zcat_cmd
,
label_filename
],
stdout
=
subprocess
.
PIPE
)
l
.
stdout
.
read
(
8
)
# skip some magic bytes
while
True
:
labels
=
numpy
.
fromfile
(
l
.
stdout
,
'ubyte'
,
count
=
buffer_size
).
astype
(
"int"
)
try
:
# reader could be break.
while
True
:
labels
=
numpy
.
fromfile
(
l
.
stdout
,
'ubyte'
,
count
=
buffer_size
).
astype
(
"int"
)
if
labels
.
size
!=
buffer_size
:
break
# numpy.fromfile returns empty slice after EOF.
if
labels
.
size
!=
buffer_size
:
break
# numpy.fromfile returns empty slice after EOF.
images
=
numpy
.
fromfile
(
m
.
stdout
,
'ubyte'
,
count
=
buffer_size
*
28
*
28
).
reshape
(
(
buffer_size
,
28
*
28
)).
astype
(
'float32'
)
images
=
numpy
.
fromfile
(
m
.
stdout
,
'ubyte'
,
count
=
buffer_size
*
28
*
28
).
reshape
(
(
buffer_size
,
28
*
28
)).
astype
(
'float32'
)
images
=
images
/
255.0
*
2.0
-
1.0
images
=
images
/
255.0
*
2.0
-
1.0
for
i
in
xrange
(
buffer_size
):
yield
images
[
i
,
:],
int
(
labels
[
i
])
m
.
terminate
()
l
.
terminate
()
for
i
in
xrange
(
buffer_size
):
yield
images
[
i
,
:],
int
(
labels
[
i
])
finally
:
m
.
terminate
()
l
.
terminate
()
return
reader
...
...
python/paddle/v2/inferencer.py
0 → 100644
浏览文件 @
fc9ad34e
import
py_paddle.swig_paddle
as
api
import
topology
from
data_feeder
import
DataFeeder
import
itertools
import
numpy
__all__
=
[
'InferenceEngine'
,
'infer'
]
class
InferenceEngine
(
object
):
def
__init__
(
self
,
output
,
parameters
):
topo
=
topology
.
Topology
(
output
)
gm
=
api
.
GradientMachine
.
createFromConfigProto
(
topo
.
proto
(),
api
.
CREATE_MODE_TESTING
,
[
api
.
PARAMETER_VALUE
])
for
param
in
gm
.
getParameters
():
val
=
param
.
getBuf
(
api
.
PARAMETER_VALUE
)
name
=
param
.
getName
()
assert
isinstance
(
val
,
api
.
Vector
)
val
.
copyFromNumpyArray
(
parameters
.
get
(
name
).
flatten
())
self
.
__gradient_machine__
=
gm
self
.
__data_types__
=
topo
.
data_type
()
def
iter_infer
(
self
,
reader
,
reader_dict
=
None
):
if
reader_dict
is
None
:
reader_dict
=
self
.
default_reader_dict
()
feeder
=
DataFeeder
(
self
.
__data_types__
,
reader_dict
)
self
.
__gradient_machine__
.
start
()
for
data_batch
in
reader
():
yield
self
.
__gradient_machine__
.
forwardTest
(
feeder
(
data_batch
))
self
.
__gradient_machine__
.
finish
()
def
iter_infer_field
(
self
,
field
,
**
kwargs
):
for
result
in
self
.
iter_infer
(
**
kwargs
):
yield
[
each_result
[
field
]
for
each_result
in
result
]
def
infer
(
self
,
field
=
'value'
,
**
kwargs
):
retv
=
None
for
result
in
self
.
iter_infer_field
(
field
=
field
,
**
kwargs
):
if
retv
is
None
:
retv
=
[[]]
*
len
(
result
)
for
i
,
item
in
enumerate
(
result
):
retv
[
i
].
append
(
item
)
retv
=
[
numpy
.
concatenate
(
out
)
for
out
in
retv
]
if
len
(
retv
)
==
1
:
return
retv
[
0
]
else
:
return
retv
def
default_reader_dict
(
self
):
reader_dict
=
dict
()
for
i
,
tp
in
enumerate
(
self
.
__data_types__
):
reader_dict
[
tp
[
0
]]
=
i
return
reader_dict
def
infer
(
output
,
parameters
,
reader
,
reader_dict
=
None
,
field
=
'value'
):
inferer
=
InferenceEngine
(
output
=
output
,
parameters
=
parameters
)
return
inferer
.
infer
(
field
=
field
,
reader
=
reader
,
reader_dict
=
reader_dict
)
python/paddle/v2/reader/decorator.py
浏览文件 @
fc9ad34e
...
...
@@ -14,13 +14,13 @@
__all__
=
[
'map_readers'
,
'buffered'
,
'compose'
,
'chain'
,
'shuffle'
,
'ComposeNotAligned'
,
'batched'
'ComposeNotAligned'
,
'batched'
,
'limited'
]
from
Queue
import
Queue
from
threading
import
Thread
import
itertools
import
random
from
Queue
import
Queue
from
threading
import
Thread
def
map_readers
(
func
,
*
readers
):
...
...
@@ -213,3 +213,17 @@ def batched(reader, batch_size):
yield
batch
return
batched_reader
def
limited
(
reader
,
limit
):
"""
Limit the max number of samples that reader could return.
"""
def
limited_reader
():
for
i
,
item
in
enumerate
(
reader
()):
if
i
==
limit
:
break
yield
item
return
limited_reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录