Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
aa2bcf51
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aa2bcf51
编写于
3月 04, 2017
作者:
Y
Yu Yang
提交者:
GitHub
3月 04, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1537 from helinwang/batch
move paddle.reader.batch to paddle.batch
上级
8bef3f4d
3432b4cd
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
42 addition
and
27 deletion
+42
-27
demo/image_classification/api_v2_train.py
demo/image_classification/api_v2_train.py
+2
-2
demo/mnist/api_train_v2.py
demo/mnist/api_train_v2.py
+2
-2
python/paddle/v2/__init__.py
python/paddle/v2/__init__.py
+2
-0
python/paddle/v2/minibatch.py
python/paddle/v2/minibatch.py
+35
-0
python/paddle/v2/reader/decorator.py
python/paddle/v2/reader/decorator.py
+1
-23
未找到文件。
demo/image_classification/api_v2_train.py
浏览文件 @
aa2bcf51
...
...
@@ -66,7 +66,7 @@ def main():
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
paddle
.
reader
.
batched
(
reader
=
paddle
.
batch
(
paddle
.
dataset
.
cifar
.
test10
(),
batch_size
=
128
),
reader_dict
=
{
'image'
:
0
,
'label'
:
1
})
...
...
@@ -77,7 +77,7 @@ def main():
parameters
=
parameters
,
update_equation
=
momentum_optimizer
)
trainer
.
train
(
reader
=
paddle
.
reader
.
batched
(
reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
cifar
.
train10
(),
buf_size
=
50000
),
batch_size
=
128
),
...
...
demo/mnist/api_train_v2.py
浏览文件 @
aa2bcf51
...
...
@@ -98,7 +98,7 @@ def main():
result
.
metrics
[
'classification_error_evaluator'
]))
trainer
.
train
(
reader
=
paddle
.
reader
.
batched
(
reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
8192
),
batch_size
=
128
),
...
...
@@ -115,7 +115,7 @@ def main():
probs
=
paddle
.
infer
(
output
=
predict
,
parameters
=
parameters
,
reader
=
paddle
.
reader
.
batched
(
reader
=
paddle
.
batch
(
paddle
.
reader
.
firstn
(
paddle
.
reader
.
map_readers
(
lambda
item
:
(
item
[
0
],
),
paddle
.
dataset
.
mnist
.
test
()),
...
...
python/paddle/v2/__init__.py
浏览文件 @
aa2bcf51
...
...
@@ -28,6 +28,7 @@ import pooling
import
inference
import
networks
import
py_paddle.swig_paddle
as
api
import
minibatch
__all__
=
[
'optimizer'
,
'layer'
,
'activation'
,
'parameters'
,
'init'
,
'trainer'
,
...
...
@@ -45,3 +46,4 @@ def init(**kwargs):
infer
=
inference
.
infer
batch
=
minibatch
.
batch
python/paddle/v2/minibatch.py
0 → 100644
浏览文件 @
aa2bcf51
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def
batch
(
reader
,
batch_size
):
"""
Create a batch reader.
:param reader: the data reader to read from.
:param batch_size: batch_size
:return: the batch reader.
"""
def
batch_reader
():
r
=
reader
()
batch
=
[]
for
instance
in
r
:
batch
.
append
(
instance
)
if
len
(
batch
)
==
batch_size
:
yield
batch
batch
=
[]
if
batch
:
yield
batch
return
batch_reader
python/paddle/v2/reader/decorator.py
浏览文件 @
aa2bcf51
...
...
@@ -14,7 +14,7 @@
__all__
=
[
'map_readers'
,
'buffered'
,
'compose'
,
'chain'
,
'shuffle'
,
'ComposeNotAligned'
,
'
batched'
,
'
firstn'
'ComposeNotAligned'
,
'firstn'
]
import
itertools
...
...
@@ -193,28 +193,6 @@ def buffered(reader, size):
return
data_reader
def
batched
(
reader
,
batch_size
):
"""
Create a batched reader.
:param reader: the data reader to read from.
:param batch_size: batch_size
:return: the batched reader.
"""
def
batched_reader
():
r
=
reader
()
batch
=
[]
for
instance
in
r
:
batch
.
append
(
instance
)
if
len
(
batch
)
==
batch_size
:
yield
batch
batch
=
[]
if
batch
:
yield
batch
return
batched_reader
def
firstn
(
reader
,
n
):
"""
Limit the max number of samples that reader could return.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录