Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
590f8f06
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
590f8f06
编写于
6月 06, 2017
作者:
L
Liu Yiqun
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into fix_setup_opencv-python
上级
23591837
ddb241f4
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
431 addition
and
19 deletion
+431
-19
cmake/cpplint.cmake
cmake/cpplint.cmake
+1
-1
cmake/generic.cmake
cmake/generic.cmake
+3
-3
doc/getstarted/index_cn.rst
doc/getstarted/index_cn.rst
+1
-1
doc/getstarted/index_en.rst
doc/getstarted/index_en.rst
+1
-1
go/cmake/golang.cmake
go/cmake/golang.cmake
+1
-1
paddle/scripts/docker/build.sh
paddle/scripts/docker/build.sh
+1
-1
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+9
-1
python/paddle/v2/dataset/flowers.py
python/paddle/v2/dataset/flowers.py
+184
-0
python/paddle/v2/dataset/tests/flowers_test.py
python/paddle/v2/dataset/tests/flowers_test.py
+51
-0
python/paddle/v2/image.py
python/paddle/v2/image.py
+92
-9
python/paddle/v2/layer.py
python/paddle/v2/layer.py
+14
-0
python/paddle/v2/reader/decorator.py
python/paddle/v2/reader/decorator.py
+72
-1
python/paddle/v2/tests/test_layer.py
python/paddle/v2/tests/test_layer.py
+1
-0
未找到文件。
cmake/cpplint.cmake
浏览文件 @
590f8f06
...
...
@@ -59,7 +59,7 @@ macro(add_style_check_target TARGET_NAME)
"--filter=
${
STYLE_FILTER
}
"
"--write-success=
${
CUR_GEN
}
"
${
filename
}
DEPENDS
${
filename
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_
LIST
_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_
SOURCE
_DIR
}
)
endif
()
endforeach
()
endif
()
...
...
cmake/generic.cmake
浏览文件 @
590f8f06
...
...
@@ -182,7 +182,7 @@ function(go_library TARGET_NAME)
COMMAND env GOPATH=
${
GOPATH
}
${
CMAKE_Go_COMPILER
}
build
${
BUILD_MODE
}
-o
"
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
LIB_NAME
}
"
${
go_library_SRCS
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_
LIST
_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_
SOURCE
_DIR
}
)
add_custom_target
(
${
TARGET_NAME
}
_lib ALL DEPENDS
${
TARGET_NAME
}
_timestamp
${
go_library_DEPS
}
)
add_library
(
${
TARGET_NAME
}
STATIC IMPORTED
)
set_property
(
TARGET
${
TARGET_NAME
}
PROPERTY
...
...
@@ -199,7 +199,7 @@ function(go_binary TARGET_NAME)
COMMAND env GOPATH=
${
GOPATH
}
${
CMAKE_Go_COMPILER
}
build
-o
"
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
"
${
go_library_SRCS
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_
LIST
_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_
SOURCE
_DIR
}
)
add_custom_target
(
${
TARGET_NAME
}
ALL DEPENDS
${
TARGET_NAME
}
_timestamp
${
go_binary_DEPS
}
)
install
(
PROGRAMS
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
DESTINATION bin
)
endfunction
(
go_binary
)
...
...
@@ -213,7 +213,7 @@ function(go_test TARGET_NAME)
COMMAND env GOPATH=
${
GOPATH
}
${
CMAKE_Go_COMPILER
}
test
-c -o
"
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
"
${
go_test_SRCS
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_
LIST
_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_
SOURCE
_DIR
}
)
add_custom_target
(
${
TARGET_NAME
}
ALL DEPENDS
${
TARGET_NAME
}
_timestamp
${
go_test_DEPS
}
)
add_test
(
${
TARGET_NAME
}
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
TARGET_NAME
}
)
endfunction
(
go_test
)
...
...
doc/getstarted/index_cn.rst
浏览文件 @
590f8f06
...
...
@@ -7,4 +7,4 @@
build_and_install/index_cn.rst
concepts/use_concepts_cn.rst
- `深度学习入门课程 <http://book.paddlepaddle.org/>`_
- `深度学习入门课程 <http://book.paddlepaddle.org/
index.cn.html
>`_
doc/getstarted/index_en.rst
浏览文件 @
590f8f06
...
...
@@ -6,4 +6,4 @@ GET STARTED
build_and_install/index_en.rst
- `Deep Learning 101 <http://book.paddlepaddle.org/index.
en.
html>`_
- `Deep Learning 101 <http://book.paddlepaddle.org/index.html>`_
go/cmake/golang.cmake
浏览文件 @
590f8f06
...
...
@@ -39,7 +39,7 @@ function(GO_LIBRARY NAME BUILD_TYPE)
COMMAND env GOPATH=
${
GOPATH
}
${
CMAKE_Go_COMPILER
}
build
${
BUILD_MODE
}
-o
"
${
CMAKE_CURRENT_BINARY_DIR
}
/
${
LIB_NAME
}
"
${
CMAKE_GO_FLAGS
}
${
GO_SOURCE
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_
LIST
_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_
SOURCE
_DIR
}
)
add_custom_target
(
${
NAME
}
ALL DEPENDS
${
OUTPUT_DIR
}
/.timestamp
${
ARGN
}
)
add_dependencies
(
${
NAME
}
goGet
)
...
...
paddle/scripts/docker/build.sh
浏览文件 @
590f8f06
...
...
@@ -58,7 +58,7 @@ EOF
make
-j
`
nproc
`
if
[
${
WITH_TESTING
:-
OFF
}
==
"ON"
]
&&
[
${
RUN_TEST
:-
OFF
}
==
"ON"
]
;
then
pip uninstall
-y
py-paddle paddle
||
true
ctest
-
V
ctest
-
-output-on-failure
fi
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
590f8f06
...
...
@@ -111,6 +111,7 @@ __all__ = [
'block_expand_layer'
,
'maxout_layer'
,
'out_prod_layer'
,
'printer_layer'
,
'print_layer'
,
'priorbox_layer'
,
'cross_channel_norm_layer'
,
...
...
@@ -969,7 +970,7 @@ def fc_layer(input,
@
wrap_name_default
(
"print"
)
def
print_layer
(
input
,
name
=
None
):
def
print
er
_layer
(
input
,
name
=
None
):
"""
Print the output value of input layers. This layer is useful for debugging.
...
...
@@ -991,6 +992,13 @@ def print_layer(input, name=None):
inputs
=
[
l
.
name
for
l
in
input
],
)
# this layer don't return anything, can not be input of other layer.
# Keep print_layer for compatibility with V1 API.
# 'print_layer' does not work for V2 API because it will be changed to
# 'print' for V2 API. But 'print' is a reserved key word in python.
print_layer
=
printer_layer
@
wrap_name_default
(
"priorbox"
)
def
priorbox_layer
(
input
,
...
...
python/paddle/v2/dataset/flowers.py
0 → 100644
浏览文件 @
590f8f06
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set intopaddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes.Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import
cPickle
import
itertools
from
common
import
download
import
tarfile
import
scipy.io
as
scio
from
paddle.v2.image
import
*
import
os
import
numpy
as
np
import
paddle.v2
as
paddle
from
multiprocessing
import
cpu_count
__all__
=
[
'train'
,
'test'
,
'valid'
]
DATA_URL
=
'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL
=
'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL
=
'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5
=
'52808999861908f626f3c1f4e79d11fa'
LABEL_MD5
=
'e0620be6f572b9609742df49c70aed4d'
SETID_MD5
=
'a5357ecc9cb78c4bef273ce3793fc85c'
def
default_mapper
(
sample
):
'''
map image bytes data to type needed by model input layer
'''
img
,
label
=
sample
img
=
paddle
.
image
.
load_image_bytes
(
img
)
img
=
paddle
.
image
.
simple_transform
(
img
,
256
,
224
,
True
)
return
img
.
flatten
().
astype
(
'float32'
),
label
def
reader_creator
(
data_file
,
label_file
,
setid_file
,
dataset_name
,
mapper
=
default_mapper
,
buffered_size
=
1024
):
'''
1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/
2. get a reader to read sample from batch file
:param data_file: downloaded data file
:type data_file: string
:param label_file: downloaded label file
:type label_file: string
:param setid_file: downloaded setid file containing information
about how to split dataset
:type setid_file: string
:param dataset_name: data set name (tstid|trnid|valid)
:type dataset_name: string
:param mapper: a function to map image bytes data to type
needed by model input layer
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: data reader
:rtype: callable
'''
labels
=
scio
.
loadmat
(
label_file
)[
'labels'
][
0
]
indexes
=
scio
.
loadmat
(
setid_file
)[
dataset_name
][
0
]
img2label
=
{}
for
i
in
indexes
:
img
=
"jpg/image_%05d.jpg"
%
i
img2label
[
img
]
=
labels
[
i
-
1
]
file_list
=
batch_images_from_tar
(
data_file
,
dataset_name
,
img2label
)
def
reader
():
for
file
in
open
(
file_list
):
file
=
file
.
strip
()
batch
=
None
with
open
(
file
,
'r'
)
as
f
:
batch
=
cPickle
.
load
(
f
)
data
=
batch
[
'data'
]
labels
=
batch
[
'label'
]
for
sample
,
label
in
itertools
.
izip
(
data
,
batch
[
'label'
]):
yield
sample
,
int
(
label
)
return
paddle
.
reader
.
xmap_readers
(
mapper
,
reader
,
cpu_count
(),
buffered_size
)
def
train
(
mapper
=
default_mapper
,
buffered_size
=
1024
):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: train data reader
:rtype: callable
'''
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'trnid'
,
mapper
,
buffered_size
)
def
test
(
mapper
=
default_mapper
,
buffered_size
=
1024
):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
'''
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'tstid'
,
mapper
,
buffered_size
)
def
valid
(
mapper
=
default_mapper
,
buffered_size
=
1024
):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
'''
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'valid'
,
mapper
,
buffered_size
)
def
fetch
():
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
)
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
)
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
)
python/paddle/v2/dataset/tests/flowers_test.py
0 → 100644
浏览文件 @
590f8f06
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.v2.dataset.flowers
import
unittest
class
TestFlowers
(
unittest
.
TestCase
):
def
check_reader
(
self
,
reader
):
sum
=
0
label
=
0
size
=
224
*
224
*
3
for
l
in
reader
():
self
.
assertEqual
(
l
[
0
].
size
,
size
)
if
l
[
1
]
>
label
:
label
=
l
[
1
]
sum
+=
1
return
sum
,
label
def
test_train
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
train
())
self
.
assertEqual
(
instances
,
1020
)
self
.
assertEqual
(
max_label_value
,
102
)
def
test_test
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
test
())
self
.
assertEqual
(
instances
,
6149
)
self
.
assertEqual
(
max_label_value
,
102
)
def
test_valid
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
valid
())
self
.
assertEqual
(
instances
,
1020
)
self
.
assertEqual
(
max_label_value
,
102
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/image.py
浏览文件 @
590f8f06
import
numpy
as
np
try
:
import
cv2
except
:
print
(
"import cv2 error, please install python wrapper of opencv using:
\n
"
" pip install opencv-python
\n
"
" or
\n
"
" apt-get install python-opencv
\n
"
)
except
ImportError
:
cv2
=
None
import
os
import
tarfile
import
cPickle
__all__
=
[
"load_image"
,
"resize_short"
,
"to_chw"
,
"center_crop"
,
"random_crop"
,
"left_right_flip"
,
"simple_transform"
,
"load_and_transform"
"load_image_bytes"
,
"load_image"
,
"resize_short"
,
"to_chw"
,
"center_crop"
,
"random_crop"
,
"left_right_flip"
,
"simple_transform"
,
"load_and_transform"
,
"batch_images_from_tar"
]
"""
This file contains some common interfaces for image preprocess.
...
...
@@ -31,6 +30,90 @@ the image layout as follows.
"""
def
batch_images_from_tar
(
data_file
,
dataset_name
,
img2label
,
num_per_batch
=
1024
):
"""
Read images from tar file and batch them into batch file.
param data_file: path of image tar file
type data_file: string
param dataset_name: 'train','test' or 'valid'
type dataset_name: string
param img2label: a dic with image file name as key
and image's label as value
type img2label: dic
param num_per_batch: image number per batch file
type num_per_batch: int
return: path of list file containing paths of batch file
rtype: string
"""
batch_dir
=
data_file
+
"_batch"
out_path
=
"%s/%s"
%
(
batch_dir
,
dataset_name
)
meta_file
=
"%s/%s.txt"
%
(
batch_dir
,
dataset_name
)
if
os
.
path
.
exists
(
out_path
):
return
meta_file
else
:
os
.
makedirs
(
out_path
)
tf
=
tarfile
.
open
(
data_file
)
mems
=
tf
.
getmembers
()
data
=
[]
labels
=
[]
file_id
=
0
for
mem
in
mems
:
if
mem
.
name
in
img2label
:
data
.
append
(
tf
.
extractfile
(
mem
).
read
())
labels
.
append
(
img2label
[
mem
.
name
])
if
len
(
data
)
==
num_per_batch
:
output
=
{}
output
[
'label'
]
=
labels
output
[
'data'
]
=
data
cPickle
.
dump
(
output
,
open
(
'%s/batch_%d'
%
(
out_path
,
file_id
),
'w'
),
protocol
=
cPickle
.
HIGHEST_PROTOCOL
)
file_id
+=
1
data
=
[]
labels
=
[]
if
len
(
data
)
>
0
:
output
=
{}
output
[
'label'
]
=
labels
output
[
'data'
]
=
data
cPickle
.
dump
(
output
,
open
(
'%s/batch_%d'
%
(
out_path
,
file_id
),
'w'
),
protocol
=
cPickle
.
HIGHEST_PROTOCOL
)
with
open
(
meta_file
,
'a'
)
as
meta
:
for
file
in
os
.
listdir
(
out_path
):
meta
.
write
(
os
.
path
.
abspath
(
"%s/%s"
%
(
out_path
,
file
))
+
"
\n
"
)
return
meta_file
def
load_image_bytes
(
bytes
,
is_color
=
True
):
"""
Load an color or gray image from bytes array.
Example usage:
.. code-block:: python
with open('cat.jpg') as f:
im = load_image_bytes(f.read())
:param bytes: the input image bytes array.
:type file: str
:param is_color: If set is_color True, it will load and
return a color image. Otherwise, it will
load and return a gray image.
"""
flag
=
1
if
is_color
else
0
file_bytes
=
np
.
asarray
(
bytearray
(
bytes
),
dtype
=
np
.
uint8
)
img
=
cv2
.
imdecode
(
file_bytes
,
flag
)
return
img
def
load_image
(
file
,
is_color
=
True
):
"""
Load an color or gray image from the file path.
...
...
python/paddle/v2/layer.py
浏览文件 @
590f8f06
...
...
@@ -149,6 +149,20 @@ def __get_used_layers__(output_layers, extra_layers=None):
for
layer
in
output_layers
:
dfs_travel
(
layer
.
full_name
)
# print layer needs to be specially handled because no other
# layer depends on it. It is used to print the result of some
# layers when running the model for debug purpose. So we explicitly
# add a print layer to the topolty if its input is in the toplogy.
for
layer
in
cp
.
g_config
.
model_config
.
layers
:
if
layer
.
type
==
'print'
:
used
=
True
for
inp
in
layer
.
inputs
:
if
inp
.
input_layer_name
not
in
layer_names
:
used
=
False
break
if
used
:
layer_names
.
add
(
layer
.
name
)
return
layer_names
...
...
python/paddle/v2/reader/decorator.py
浏览文件 @
590f8f06
...
...
@@ -14,7 +14,7 @@
__all__
=
[
'map_readers'
,
'buffered'
,
'compose'
,
'chain'
,
'shuffle'
,
'ComposeNotAligned'
,
'firstn'
'ComposeNotAligned'
,
'firstn'
,
'xmap_readers'
]
import
itertools
...
...
@@ -224,3 +224,74 @@ def firstn(reader, n):
yield
item
return
firstn_reader
class
XmapEndSignal
():
pass
def
xmap_readers
(
mapper
,
reader
,
process_num
,
buffer_size
):
"""
Use multiprocess to map samples from reader by a mapper defined by user.
And this function contains a buffered decorator.
:param mapper: a function to map sample.
:type mapper: callable
:param reader: the data reader to read from
:type reader: callable
:param process_num: process number to handle original sample
:type process_num: int
:param buffer_size: max buffer size
:type buffer_size: int
:return: the decarated reader
:rtype: callable
"""
end
=
XmapEndSignal
()
in_queue
=
Queue
(
buffer_size
)
out_queue
=
Queue
(
buffer_size
)
# define a worker to read samples from reader to in_queue
def
read_worker
(
reader
,
in_queue
):
for
i
in
reader
():
in_queue
.
put
(
i
)
in_queue
.
put
(
end
)
# start a read worker in a thread
t
=
Thread
(
target
=
read_worker
,
args
=
(
reader
,
in_queue
))
t
.
daemon
=
True
t
.
start
()
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue
def
handle_worker
(
in_queue
,
out_queue
,
mapper
):
sample
=
in_queue
.
get
()
while
not
isinstance
(
sample
,
XmapEndSignal
):
r
=
mapper
(
sample
)
out_queue
.
put
(
r
)
sample
=
in_queue
.
get
()
in_queue
.
put
(
end
)
out_queue
.
put
(
end
)
# start several handle_workers
workers
=
[]
for
i
in
xrange
(
process_num
):
worker
=
Thread
(
target
=
handle_worker
,
args
=
(
in_queue
,
out_queue
,
mapper
))
worker
.
daemon
=
True
workers
.
append
(
worker
)
for
w
in
workers
:
w
.
start
()
def
xreader
():
sample
=
out_queue
.
get
()
while
not
isinstance
(
sample
,
XmapEndSignal
):
yield
sample
sample
=
out_queue
.
get
()
finish
=
1
while
finish
<
process_num
:
sample
=
out_queue
.
get
()
if
isinstance
(
sample
,
XmapEndSignal
):
finish
+=
1
else
:
yield
sample
return
xreader
python/paddle/v2/tests/test_layer.py
浏览文件 @
590f8f06
...
...
@@ -164,6 +164,7 @@ class OtherLayerTest(unittest.TestCase):
maxid
=
layer
.
max_id
(
input
=
inference
)
sampling_id
=
layer
.
sampling_id
(
input
=
inference
)
eos
=
layer
.
eos
(
input
=
maxid
,
eos_id
=
5
)
layer
.
printer
(
maxid
)
print
layer
.
parse_network
([
maxid
,
sampling_id
,
eos
])
def
test_slicing_joining_layer
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录