Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3334c279
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3334c279
编写于
2月 27, 2019
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sample_generator
test=develop
上级
7b5a9d75
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
264 addition
and
27 deletion
+264
-27
paddle/fluid/API.spec
paddle/fluid/API.spec
+2
-1
paddle/fluid/operators/reader/blocking_queue.h
paddle/fluid/operators/reader/blocking_queue.h
+1
-0
paddle/fluid/operators/reader/buffered_reader.cc
paddle/fluid/operators/reader/buffered_reader.cc
+1
-0
paddle/fluid/operators/reader/buffered_reader.h
paddle/fluid/operators/reader/buffered_reader.h
+1
-0
paddle/fluid/operators/reader/py_reader.cc
paddle/fluid/operators/reader/py_reader.cc
+1
-0
paddle/fluid/operators/reader/py_reader.h
paddle/fluid/operators/reader/py_reader.h
+1
-0
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+3
-0
python/paddle/fluid/data_feeder.py
python/paddle/fluid/data_feeder.py
+61
-21
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+56
-4
python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py
.../paddle/fluid/tests/unittests/test_decoupled_py_reader.py
+0
-1
python/paddle/fluid/tests/unittests/test_py_reader_sample_generator.py
.../fluid/tests/unittests/test_py_reader_sample_generator.py
+137
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
3334c279
...
@@ -61,8 +61,9 @@ paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program',
...
@@ -61,8 +61,9 @@ paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program',
paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.io.PyReader.__init__ ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable'], varargs=None, keywords=None, defaults=(True,
Tru
e))
paddle.fluid.io.PyReader.__init__ ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable'], varargs=None, keywords=None, defaults=(True,
Fals
e))
paddle.fluid.io.PyReader.decorate_paddle_reader ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.io.PyReader.decorate_paddle_reader ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.io.PyReader.decorate_sample_generator ArgSpec(args=['self', 'sample_generator', 'batch_size', 'drop_last', 'places'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.io.PyReader.decorate_tensor_provider ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.io.PyReader.decorate_tensor_provider ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.io.PyReader.reset ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.io.PyReader.reset ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.io.PyReader.start ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.io.PyReader.start ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/operators/reader/blocking_queue.h
浏览文件 @
3334c279
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <condition_variable> // NOLINT
#include <condition_variable> // NOLINT
#include <deque>
#include <deque>
#include <utility>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
...
paddle/fluid/operators/reader/buffered_reader.cc
浏览文件 @
3334c279
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
...
...
paddle/fluid/operators/reader/buffered_reader.h
浏览文件 @
3334c279
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <list>
#include <list>
#include <memory>
#include <queue>
#include <queue>
#include <vector>
#include <vector>
#include "ThreadPool.h"
#include "ThreadPool.h"
...
...
paddle/fluid/operators/reader/py_reader.cc
浏览文件 @
3334c279
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include <memory>
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/operators/reader/py_reader.h
浏览文件 @
3334c279
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <atomic>
#include <atomic>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
3334c279
...
@@ -13,7 +13,10 @@
...
@@ -13,7 +13,10 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/pybind/reader_py.h"
#include "paddle/fluid/pybind/reader_py.h"
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
...
...
python/paddle/fluid/data_feeder.py
浏览文件 @
3334c279
...
@@ -26,6 +26,24 @@ from .framework import Variable, default_main_program
...
@@ -26,6 +26,24 @@ from .framework import Variable, default_main_program
__all__
=
[
'DataFeeder'
]
__all__
=
[
'DataFeeder'
]
def
convert_dtype
(
dtype
):
if
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
return
'float32'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT64
:
return
'int64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP64
:
return
'float64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
return
'float16'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT32
:
return
'int32'
elif
dtype
==
core
.
VarDesc
.
VarType
.
UINT8
:
return
'uint8'
else
:
raise
ValueError
(
"dtype must be any of [int32, float32, int64, "
"float64, uint8]"
)
class
DataToLoDTensorConverter
(
object
):
class
DataToLoDTensorConverter
(
object
):
def
__init__
(
self
,
place
,
lod_level
,
shape
,
dtype
):
def
__init__
(
self
,
place
,
lod_level
,
shape
,
dtype
):
self
.
place
=
place
self
.
place
=
place
...
@@ -38,27 +56,12 @@ class DataToLoDTensorConverter(object):
...
@@ -38,27 +56,12 @@ class DataToLoDTensorConverter(object):
if
negtive_count
>
1
:
if
negtive_count
>
1
:
self
.
shape
=
None
self
.
shape
=
None
break
break
if
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
self
.
dtype
=
convert_dtype
(
dtype
)
self
.
dtype
=
'float32'
self
.
_reset
()
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT64
:
self
.
dtype
=
'int64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP64
:
self
.
dtype
=
'float64'
elif
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
self
.
dtype
=
'float16'
elif
dtype
==
core
.
VarDesc
.
VarType
.
INT32
:
self
.
dtype
=
'int32'
elif
dtype
==
core
.
VarDesc
.
VarType
.
UINT8
:
self
.
dtype
=
'uint8'
else
:
raise
ValueError
(
"dtype must be any of [int32, float32, int64, "
"float64, uint8]"
)
def
_reset
(
self
):
self
.
data
=
[]
self
.
data
=
[]
self
.
lod
=
[]
self
.
lod
=
[[]
for
_
in
six
.
moves
.
range
(
self
.
lod_level
)]
for
i
in
six
.
moves
.
range
(
lod_level
):
self
.
lod
.
append
([])
def
feed
(
self
,
data
):
def
feed
(
self
,
data
):
self
.
_feed_impl_
(
data
,
self
.
lod
,
self
.
lod_level
)
self
.
_feed_impl_
(
data
,
self
.
lod
,
self
.
lod_level
)
...
@@ -88,15 +91,52 @@ class DataToLoDTensorConverter(object):
...
@@ -88,15 +91,52 @@ class DataToLoDTensorConverter(object):
raise
ValueError
(
raise
ValueError
(
"Reshape error. What is defined in data layer is {}, but receive {}"
"Reshape error. What is defined in data layer is {}, but receive {}"
.
format
(
self
.
shape
,
arr
.
shape
))
.
format
(
self
.
shape
,
arr
.
shape
))
#else:
# self._check_shape(arr.shape)
t
=
core
.
LoDTensor
()
t
=
core
.
LoDTensor
()
t
.
set
(
arr
,
self
.
place
)
t
.
set
(
arr
,
self
.
place
)
if
self
.
lod_level
>
0
:
if
self
.
lod_level
>
0
:
t
.
set_recursive_sequence_lengths
(
self
.
lod
)
t
.
set_recursive_sequence_lengths
(
self
.
lod
)
self
.
_reset
()
return
t
return
t
class
BatchedTensorProvider
(
object
):
def
__init__
(
self
,
feed_list
,
place
,
batch_size
,
generator
,
drop_last
):
self
.
place
=
place
self
.
batch_size
=
batch_size
self
.
generator
=
generator
self
.
converters
=
[]
self
.
drop_last
=
drop_last
for
var
in
feed_list
:
assert
var
.
lod_level
==
0
,
"lod_level must be 0"
self
.
converters
.
append
(
DataToLoDTensorConverter
(
place
=
self
.
place
,
lod_level
=
0
,
shape
=
var
.
shape
,
dtype
=
var
.
dtype
))
def
_done
(
self
):
return
[
c
.
done
()
for
c
in
self
.
converters
]
def
__call__
(
self
):
idx
=
0
for
each_sample
in
self
.
generator
():
for
each_slot
,
each_converter
in
six
.
moves
.
zip
(
each_sample
,
self
.
converters
):
each_converter
.
data
.
append
(
each_slot
)
idx
+=
1
if
idx
==
self
.
batch_size
:
idx
=
0
yield
self
.
_done
()
if
not
self
.
drop_last
and
idx
>
0
:
yield
self
.
_done
()
else
:
[
c
.
_reset
()
for
c
in
self
.
converters
]
class
DataFeeder
(
object
):
class
DataFeeder
(
object
):
"""
"""
DataFeeder converts the data that returned by a reader into a data
DataFeeder converts the data that returned by a reader into a data
...
...
python/paddle/fluid/reader.py
浏览文件 @
3334c279
...
@@ -17,7 +17,7 @@ import six
...
@@ -17,7 +17,7 @@ import six
import
threading
import
threading
from
.framework
import
Program
,
Variable
,
program_guard
,
default_main_program
,
default_startup_program
from
.framework
import
Program
,
Variable
,
program_guard
,
default_main_program
,
default_startup_program
from
.executor
import
global_scope
from
.executor
import
global_scope
from
.data_feeder
import
DataFeeder
from
.data_feeder
import
DataFeeder
,
BatchedTensorProvider
from
.layers.io
import
monkey_patch_reader_methods
,
_copy_reader_var_
,
double_buffer
from
.layers.io
import
monkey_patch_reader_methods
,
_copy_reader_var_
,
double_buffer
from
.unique_name
import
UniqueNameGenerator
from
.unique_name
import
UniqueNameGenerator
...
@@ -46,7 +46,7 @@ class PyReader(object):
...
@@ -46,7 +46,7 @@ class PyReader(object):
feed_list
,
feed_list
,
capacity
,
capacity
,
use_double_buffer
=
True
,
use_double_buffer
=
True
,
iterable
=
Tru
e
):
iterable
=
Fals
e
):
"""
"""
Create a reader object for data feeding in Python.
Create a reader object for data feeding in Python.
Data would be prefetched using Python thread and be pushed
Data would be prefetched using Python thread and be pushed
...
@@ -269,6 +269,54 @@ class PyReader(object):
...
@@ -269,6 +269,54 @@ class PyReader(object):
self
.
_thread
.
daemon
=
True
self
.
_thread
.
daemon
=
True
self
.
_thread
.
start
()
self
.
_thread
.
start
()
def
decorate_sample_generator
(
self
,
sample_generator
,
batch_size
,
drop_last
=
True
,
places
=
None
):
'''
Set the data source of the PyReader object.
The provided :code:`sample_generator` should be a Python generator,
which yields numpy.ndarray typed data of each sample.
:code:`places` must be set when the PyReader object is iterable.
If all inputs have no lods, this method is faster than
:code:`decorate_paddle_reader(paddle.batch(sample_generator, ...))` .
Args:
sample_generator (generator): Python generator that yields
numpy.ndarray-typed sample data.
batch_size (int): batch size. Must be larger than 0.
drop_last (bool): Whether to drop the last batch when sample number
is less than batch_size.
places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must
be provided when PyReader is iterable.
'''
assert
batch_size
>
0
,
"batch_size must be larger than 0"
has_lod
=
False
for
f
in
self
.
_feed_list
:
if
f
.
lod_level
!=
0
:
has_lod
=
True
break
if
has_lod
:
self
.
decorate_paddle_reader
(
paddle
.
batch
(
sample_generator
,
batch_size
=
batch_size
,
drop_last
=
drop_last
),
places
=
places
)
else
:
reader
=
BatchedTensorProvider
(
feed_list
=
self
.
_feed_list
,
place
=
core
.
CPUPlace
(),
batch_size
=
batch_size
,
generator
=
sample_generator
,
drop_last
=
drop_last
)
self
.
decorate_tensor_provider
(
reader
,
places
=
places
)
def
decorate_paddle_reader
(
self
,
reader
,
places
=
None
):
def
decorate_paddle_reader
(
self
,
reader
,
places
=
None
):
'''
'''
Set the data source of the PyReader object.
Set the data source of the PyReader object.
...
@@ -279,8 +327,10 @@ class PyReader(object):
...
@@ -279,8 +327,10 @@ class PyReader(object):
:code:`places` must be set when the PyReader object is iterable.
:code:`places` must be set when the PyReader object is iterable.
Args:
Args:
reader (generator): Python generator that yields numpy-typed
reader (generator): Python generator that yields
batched data.
list(numpy.ndarray)-typed batched data.
places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must
be provided when PyReader is iterable.
'''
'''
assert
self
.
_tensor_reader
is
None
,
\
assert
self
.
_tensor_reader
is
None
,
\
"Cannot reset the data source of PyReader"
"Cannot reset the data source of PyReader"
...
@@ -307,6 +357,8 @@ class PyReader(object):
...
@@ -307,6 +357,8 @@ class PyReader(object):
Args:
Args:
reader (generator): Python generator that yields LoDTensor-typed
reader (generator): Python generator that yields LoDTensor-typed
batched data.
batched data.
places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must
be provided when PyReader is iterable.
'''
'''
assert
self
.
_tensor_reader
is
None
,
\
assert
self
.
_tensor_reader
is
None
,
\
"Cannot reset the data source of PyReader"
"Cannot reset the data source of PyReader"
...
...
python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py
浏览文件 @
3334c279
...
@@ -127,7 +127,6 @@ class TestBase(unittest.TestCase):
...
@@ -127,7 +127,6 @@ class TestBase(unittest.TestCase):
step_list
.
append
(
step
)
step_list
.
append
(
step
)
end_t
=
time
.
time
()
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
}
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
}
scope
.
_remove_from_pool
()
return
ret
return
ret
def
prepare_places
(
self
,
with_data_parallel
,
with_cpu
=
True
,
with_gpu
=
True
):
def
prepare_places
(
self
,
with_data_parallel
,
with_cpu
=
True
,
with_gpu
=
True
):
...
...
python/paddle/fluid/tests/unittests/test_py_reader_sample_generator.py
0 → 100644
浏览文件 @
3334c279
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.fluid
as
fluid
import
math
import
unittest
import
numpy
as
np
import
os
os
.
environ
[
'CPU_NUM'
]
=
'1'
def
random_reader
(
sample_num
):
def
__impl__
():
for
_
in
range
(
sample_num
):
yield
np
.
random
.
random
(
size
=
[
784
]).
astype
(
'float32'
),
np
.
random
.
random_integers
(
low
=
0
,
high
=
9
,
size
=
[
1
]).
astype
(
'int64'
)
return
paddle
.
reader
.
cache
(
__impl__
)
class
TestCaseBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
32
self
.
epoch_num
=
2
self
.
sample_num
=
165
def
generate_all_data
(
self
,
reader
):
ret
=
[]
for
d
in
reader
():
slots
=
[[],
[]]
for
item
in
d
:
slots
[
0
].
append
(
item
[
0
])
slots
[
1
].
append
(
item
[
1
])
slots
=
[
np
.
array
(
slot
)
for
slot
in
slots
]
ret
.
append
(
slots
)
return
ret
def
run_main
(
self
,
reader
,
use_sample_generator
,
iterable
,
drop_last
):
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
dtype
=
'float32'
,
shape
=
[
784
])
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
dtype
=
'int64'
,
shape
=
[
1
])
py_reader
=
fluid
.
io
.
PyReader
(
feed_list
=
[
image
,
label
],
capacity
=
16
,
iterable
=
iterable
,
use_double_buffer
=
False
)
batch_reader
=
paddle
.
batch
(
reader
,
self
.
batch_size
,
drop_last
)
all_datas
=
self
.
generate_all_data
(
batch_reader
)
if
not
use_sample_generator
:
py_reader
.
decorate_paddle_reader
(
batch_reader
,
places
=
fluid
.
cpu_places
())
else
:
py_reader
.
decorate_sample_generator
(
reader
,
self
.
batch_size
,
drop_last
,
places
=
fluid
.
cpu_places
())
if
drop_last
:
batch_num
=
int
(
self
.
sample_num
/
self
.
batch_size
)
else
:
batch_num
=
math
.
ceil
(
float
(
self
.
sample_num
)
/
self
.
batch_size
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
())
for
_
in
range
(
self
.
epoch_num
):
if
py_reader
.
iterable
:
step
=
0
for
data
in
py_reader
():
img
,
lbl
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[
image
,
label
])
self
.
assertArrayEqual
(
img
,
all_datas
[
step
][
0
])
self
.
assertArrayEqual
(
lbl
,
all_datas
[
step
][
1
])
step
+=
1
self
.
assertEqual
(
step
,
len
(
all_datas
))
else
:
step
=
0
try
:
py_reader
.
start
()
while
True
:
img
,
lbl
=
exe
.
run
(
fetch_list
=
[
image
,
label
])
self
.
assertArrayEqual
(
img
,
all_datas
[
step
][
0
])
self
.
assertArrayEqual
(
lbl
,
all_datas
[
step
][
1
])
step
+=
1
except
fluid
.
core
.
EOFException
:
py_reader
.
reset
()
self
.
assertEqual
(
step
,
len
(
all_datas
))
break
def
assertArrayEqual
(
self
,
arr1
,
arr2
):
self
.
assertEqual
(
arr1
.
shape
,
arr2
.
shape
)
self
.
assertTrue
((
arr1
==
arr2
).
all
())
def
test_main
(
self
):
reader
=
random_reader
(
self
.
sample_num
)
for
use_sample_generator
in
[
False
,
True
]:
for
iterable
in
[
False
,
True
]:
for
drop_last
in
[
False
,
True
]:
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
self
.
run_main
(
reader
,
use_sample_generator
,
iterable
,
drop_last
)
class
TestCase1
(
TestCaseBase
):
def
setUp
(
self
):
self
.
batch_size
=
32
self
.
epoch_num
=
10
self
.
sample_num
=
160
class
TestCase2
(
TestCaseBase
):
def
setUp
(
self
):
self
.
batch_size
=
32
self
.
epoch_num
=
2
self
.
sample_num
=
200
class
TestCase3
(
TestCaseBase
):
def
setUp
(
self
):
self
.
batch_size
=
32
self
.
epoch_num
=
2
self
.
sample_num
=
159
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录