提交 15668482 编写于 作者: X xzl

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into improve_pruning

...@@ -59,7 +59,7 @@ macro(add_style_check_target TARGET_NAME) ...@@ -59,7 +59,7 @@ macro(add_style_check_target TARGET_NAME)
"--filter=${STYLE_FILTER}" "--filter=${STYLE_FILTER}"
"--write-success=${CUR_GEN}" ${filename} "--write-success=${CUR_GEN}" ${filename}
DEPENDS ${filename} DEPENDS ${filename}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif() endif()
endforeach() endforeach()
endif() endif()
......
...@@ -11,23 +11,16 @@ find_path(CUDNN_INCLUDE_DIR cudnn.h ...@@ -11,23 +11,16 @@ find_path(CUDNN_INCLUDE_DIR cudnn.h
get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH) get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
if(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR}) set(TARGET_ARCH "x86_64")
execute_process( if(NOT ${CMAKE_SYSTEM_PROCESSOR})
COMMAND uname -m COMMAND tr -d '\n' set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
OUTPUT_VARIABLE HOST_ARCH endif()
RESULT_VARIABLE UNAME_RESULT)
if(${UNAME_RESULT})
set(HOST_ARCH "x86_64")
endif(${UNAME_RESULT})
else(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
set(HOST_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
endif(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
list(APPEND CUDNN_CHECK_LIBRARY_DIRS list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDNN_ROOT} ${CUDNN_ROOT}
${CUDNN_ROOT}/lib64 ${CUDNN_ROOT}/lib64
${CUDNN_ROOT}/lib ${CUDNN_ROOT}/lib
${CUDNN_ROOT}/lib/${HOST_ARCH}-linux-gnu ${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64 $ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib $ENV{CUDNN_ROOT}/lib
......
...@@ -24,20 +24,25 @@ IF(NOT ${CBLAS_FOUND}) ...@@ -24,20 +24,25 @@ IF(NOT ${CBLAS_FOUND})
SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}" SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}"
CACHE FILEPATH "openblas library." FORCE) CACHE FILEPATH "openblas library." FORCE)
SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1) SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs)
IF(CMAKE_CROSSCOMPILING)
IF(ANDROID) IF(ANDROID)
# arm_soft_fp_abi branch of OpenBLAS to support softfp # arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5") SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0 libs) SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI) ELSEIF(RPI)
# use hardfp # use hardfp
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0 libs) SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0)
ENDIF()
ELSE() ELSE()
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 libs NUM_THREADS=64) SET(OPTIONAL_ARGS "")
IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$")
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64)
ENDIF()
ENDIF() ENDIF()
ExternalProject_Add( ExternalProject_Add(
......
...@@ -182,7 +182,7 @@ function(go_library TARGET_NAME) ...@@ -182,7 +182,7 @@ function(go_library TARGET_NAME)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${go_library_SRCS} ${go_library_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME}_lib ALL DEPENDS ${TARGET_NAME}_timestamp ${go_library_DEPS}) add_custom_target(${TARGET_NAME}_lib ALL DEPENDS ${TARGET_NAME}_timestamp ${go_library_DEPS})
add_library(${TARGET_NAME} STATIC IMPORTED) add_library(${TARGET_NAME} STATIC IMPORTED)
set_property(TARGET ${TARGET_NAME} PROPERTY set_property(TARGET ${TARGET_NAME} PROPERTY
...@@ -199,7 +199,7 @@ function(go_binary TARGET_NAME) ...@@ -199,7 +199,7 @@ function(go_binary TARGET_NAME)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_library_SRCS} ${go_library_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_binary_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary) endfunction(go_binary)
...@@ -213,7 +213,7 @@ function(go_test TARGET_NAME) ...@@ -213,7 +213,7 @@ function(go_test TARGET_NAME)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS} ${go_test_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
......
...@@ -7,4 +7,4 @@ ...@@ -7,4 +7,4 @@
build_and_install/index_cn.rst build_and_install/index_cn.rst
concepts/use_concepts_cn.rst concepts/use_concepts_cn.rst
- `深度学习入门课程 <http://book.paddlepaddle.org/>`_ - `深度学习入门课程 <http://book.paddlepaddle.org/index.cn.html>`_
...@@ -6,4 +6,4 @@ GET STARTED ...@@ -6,4 +6,4 @@ GET STARTED
build_and_install/index_en.rst build_and_install/index_en.rst
- `Deep Learning 101 <http://book.paddlepaddle.org/index.en.html>`_ - `Deep Learning 101 <http://book.paddlepaddle.org/index.html>`_
...@@ -39,7 +39,7 @@ function(GO_LIBRARY NAME BUILD_TYPE) ...@@ -39,7 +39,7 @@ function(GO_LIBRARY NAME BUILD_TYPE)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE} ${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet) add_dependencies(${NAME} goGet)
......
...@@ -58,7 +58,7 @@ EOF ...@@ -58,7 +58,7 @@ EOF
make -j `nproc` make -j `nproc`
if [ ${WITH_TESTING:-OFF} == "ON" ] && [ ${RUN_TEST:-OFF} == "ON" ] ; then if [ ${WITH_TESTING:-OFF} == "ON" ] && [ ${RUN_TEST:-OFF} == "ON" ] ; then
pip uninstall -y py-paddle paddle || true pip uninstall -y py-paddle paddle || true
ctest -V ctest --output-on-failure
fi fi
......
...@@ -111,6 +111,7 @@ __all__ = [ ...@@ -111,6 +111,7 @@ __all__ = [
'block_expand_layer', 'block_expand_layer',
'maxout_layer', 'maxout_layer',
'out_prod_layer', 'out_prod_layer',
'printer_layer',
'print_layer', 'print_layer',
'priorbox_layer', 'priorbox_layer',
'cross_channel_norm_layer', 'cross_channel_norm_layer',
...@@ -969,7 +970,7 @@ def fc_layer(input, ...@@ -969,7 +970,7 @@ def fc_layer(input,
@wrap_name_default("print") @wrap_name_default("print")
def print_layer(input, name=None): def printer_layer(input, name=None):
""" """
Print the output value of input layers. This layer is useful for debugging. Print the output value of input layers. This layer is useful for debugging.
...@@ -991,6 +992,13 @@ def print_layer(input, name=None): ...@@ -991,6 +992,13 @@ def print_layer(input, name=None):
inputs=[l.name for l in input], ) inputs=[l.name for l in input], )
# this layer don't return anything, can not be input of other layer. # this layer don't return anything, can not be input of other layer.
# Keep print_layer for compatibility with V1 API.
# 'print_layer' does not work for V2 API because it will be changed to
# 'print' for V2 API. But 'print' is a reserved key word in python.
print_layer = printer_layer
@wrap_name_default("priorbox") @wrap_name_default("priorbox")
def priorbox_layer(input, def priorbox_layer(input,
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse the train/test/validation sets into paddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes. Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
from common import download
import tarfile
import scipy.io as scio
from paddle.v2.image import *
import os
import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count
# Public reader creators exported by this module.
__all__ = ['train', 'test', 'valid']
# Download locations of the image archive, the label file and the
# dataset-split (set id) file.
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
# MD5 values handed to download() together with the matching URL above.
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
def default_mapper(sample):
    '''
    Convert one raw (image bytes, label) sample into model input.

    The image bytes are decoded with paddle.image.load_image_bytes, passed
    through paddle.image.simple_transform(img, 256, 224, True) and finally
    flattened into a float32 vector. The label is returned untouched.

    :param sample: a pair of (encoded image bytes, label)
    :type sample: tuple
    :return: (flattened float32 image, label)
    :rtype: tuple
    '''
    raw_bytes, label = sample
    decoded = paddle.image.load_image_bytes(raw_bytes)
    # NOTE(review): 256 / 224 / True are presumably the resize size, the crop
    # size and the training flag of simple_transform -- confirm against it.
    transformed = paddle.image.simple_transform(decoded, 256, 224, True)
    flat = transformed.flatten()
    return flat.astype('float32'), label
def reader_creator(data_file,
                   label_file,
                   setid_file,
                   dataset_name,
                   mapper=default_mapper,
                   buffered_size=1024):
    '''
    Build a reader over one split of the flowers dataset.

    1. Images named in the split are read from the tar file and merged into
       pickled batch files by batch_images_from_tar.
    2. The returned reader walks those batch files, yields
       (image bytes, int label) samples and maps them with ``mapper``
       through paddle.reader.xmap_readers.

    :param data_file: downloaded data file
    :type data_file: string
    :param label_file: downloaded label file
    :type label_file: string
    :param setid_file: downloaded setid file containing information
                       about how to split dataset
    :type setid_file: string
    :param dataset_name: data set name (tstid|trnid|valid)
    :type dataset_name: string
    :param mapper: a function to map image bytes data to type
                   needed by model input layer
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: data reader
    :rtype: callable
    '''
    labels = scio.loadmat(label_file)['labels'][0]
    indexes = scio.loadmat(setid_file)[dataset_name][0]
    img2label = {}
    for i in indexes:
        img = "jpg/image_%05d.jpg" % i
        # Image ids in the set id file start at 1, the label array at 0.
        img2label[img] = labels[i - 1]
    file_list = batch_images_from_tar(data_file, dataset_name, img2label)

    def reader():
        # Close the list file deterministically instead of leaking the handle.
        with open(file_list) as batch_list:
            for line in batch_list:
                batch_path = line.strip()
                if not batch_path:
                    continue
                # Batch files are pickled with a binary protocol, so they
                # have to be opened in binary mode.
                # NOTE: unpickling is only safe because these files are a
                # local cache written by batch_images_from_tar.
                with open(batch_path, 'rb') as f:
                    batch = cPickle.load(f)
                for sample, label in itertools.izip(batch['data'],
                                                    batch['label']):
                    yield sample, int(label)

    return paddle.reader.xmap_readers(mapper, reader,
                                      cpu_count(), buffered_size)
def train(mapper=default_mapper, buffered_size=1024):
    '''
    Create the flowers training set reader.

    The data, label and set id files are fetched with download() and the
    'trnid' split is handed to reader_creator. Every sample produced by the
    returned reader is the result of applying ``mapper`` to an
    (image bytes, label) pair; with the default mapper that is a flattened
    float32 image and its label.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: train data reader
    :rtype: callable
    '''
    data_path = download(DATA_URL, 'flowers', DATA_MD5)
    label_path = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_path = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(data_path, label_path, setid_path, 'trnid', mapper,
                          buffered_size)
def test(mapper=default_mapper, buffered_size=1024):
    '''
    Create the flowers test set reader.

    The data, label and set id files are fetched with download() and the
    'tstid' split is handed to reader_creator. Every sample produced by the
    returned reader is the result of applying ``mapper`` to an
    (image bytes, label) pair; with the default mapper that is a flattened
    float32 image and its label.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: test data reader
    :rtype: callable
    '''
    data_path = download(DATA_URL, 'flowers', DATA_MD5)
    label_path = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_path = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(data_path, label_path, setid_path, 'tstid', mapper,
                          buffered_size)
def valid(mapper=default_mapper, buffered_size=1024):
    '''
    Create the flowers validation set reader.

    The data, label and set id files are fetched with download() and the
    'valid' split is handed to reader_creator. Every sample produced by the
    returned reader is the result of applying ``mapper`` to an
    (image bytes, label) pair; with the default mapper that is a flattened
    float32 image and its label.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: validation data reader
    :rtype: callable
    '''
    data_path = download(DATA_URL, 'flowers', DATA_MD5)
    label_path = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_path = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(data_path, label_path, setid_path, 'valid', mapper,
                          buffered_size)
def fetch():
    '''
    Download the data, label and set id files of the flowers dataset
    without building a reader.
    '''
    resources = (
        (DATA_URL, DATA_MD5),
        (LABEL_URL, LABEL_MD5),
        (SETID_URL, SETID_MD5), )
    for url, md5 in resources:
        download(url, 'flowers', md5)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.flowers
import unittest
class TestFlowers(unittest.TestCase):
def check_reader(self, reader):
sum = 0
label = 0
size = 224 * 224 * 3
for l in reader():
self.assertEqual(l[0].size, size)
if l[1] > label:
label = l[1]
sum += 1
return sum, label
def test_train(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
def test_test(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102)
def test_valid(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.valid())
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
# Run the test cases of this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
import numpy as np import numpy as np
try: try:
import cv2 import cv2
except: except ImportError:
print( cv2 = None
"import cv2 error, please install opencv-python: pip install opencv-python" import os
) import tarfile
import cPickle
__all__ = [ __all__ = [
"load_image", "resize_short", "to_chw", "center_crop", "random_crop", "load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop",
"left_right_flip", "simple_transform", "load_and_transform" "random_crop", "left_right_flip", "simple_transform", "load_and_transform",
"batch_images_from_tar"
] ]
""" """
This file contains some common interfaces for image preprocess. This file contains some common interfaces for image preprocess.
...@@ -28,6 +30,90 @@ the image layout as follows. ...@@ -28,6 +30,90 @@ the image layout as follows.
""" """
def batch_images_from_tar(data_file,
                          dataset_name,
                          img2label,
                          num_per_batch=1024):
    """
    Read images from tar file and batch them into batch file.

    Batch files are written to ``<data_file>_batch/<dataset_name>/batch_<n>``
    as pickled dicts with the keys 'label' and 'data'. The absolute paths of
    all batch files are listed, one per line and in creation order, in
    ``<data_file>_batch/<dataset_name>.txt``.

    If the output directory already exists nothing is regenerated and the
    path of the list file is returned right away.

    :param data_file: path of image tar file
    :type data_file: string
    :param dataset_name: name of the split; used as directory and list file
                         name
    :type dataset_name: string
    :param img2label: a dict with image file name (tar member name) as key
                      and image's label as value
    :type img2label: dict
    :param num_per_batch: image number per batch file
    :type num_per_batch: int
    :return: path of list file containing paths of batch file
    :rtype: string
    """
    batch_dir = data_file + "_batch"
    out_path = "%s/%s" % (batch_dir, dataset_name)
    meta_file = "%s/%s.txt" % (batch_dir, dataset_name)

    # An existing output directory is treated as the result of an earlier run.
    if os.path.exists(out_path):
        return meta_file
    os.makedirs(out_path)

    batch_paths = []

    def dump_batch(batch_data, batch_labels):
        # Write one batch file; its index is the number of batches so far.
        path = '%s/batch_%d' % (out_path, len(batch_paths))
        # HIGHEST_PROTOCOL is a binary format, so the file must be opened in
        # binary mode; the with-block also guarantees the handle is closed.
        with open(path, 'wb') as batch_file:
            cPickle.dump(
                {'label': batch_labels,
                 'data': batch_data},
                batch_file,
                protocol=cPickle.HIGHEST_PROTOCOL)
        batch_paths.append(os.path.abspath(path))

    data = []
    labels = []
    tf = tarfile.open(data_file)
    try:
        for mem in tf.getmembers():
            if mem.name not in img2label:
                continue
            data.append(tf.extractfile(mem).read())
            labels.append(img2label[mem.name])
            if len(data) == num_per_batch:
                dump_batch(data, labels)
                data = []
                labels = []
    finally:
        tf.close()

    # Flush the last, partially filled batch.
    if len(data) > 0:
        dump_batch(data, labels)

    # Overwrite instead of append: a list file left over from an earlier run
    # must not contribute stale or duplicated entries.
    with open(meta_file, 'w') as meta:
        for path in batch_paths:
            meta.write(path + "\n")
    return meta_file
def load_image_bytes(bytes, is_color=True):
    """
    Decode a color or gray image from an in-memory byte string.

    Example usage:

    .. code-block:: python

        with open('cat.jpg', 'rb') as f:
            im = load_image_bytes(f.read())

    :param bytes: the encoded image content.
    :type bytes: str
    :param is_color: If True, the image is decoded as a color image.
                     Otherwise it is decoded as a gray image.
    :type is_color: bool
    :return: the value returned by cv2.imdecode for the given bytes.
    """
    # cv2.imdecode flag: 1 requests a color image, 0 a gray one.
    if is_color:
        decode_flag = 1
    else:
        decode_flag = 0
    buf = np.asarray(bytearray(bytes), dtype=np.uint8)
    return cv2.imdecode(buf, decode_flag)
def load_image(file, is_color=True): def load_image(file, is_color=True):
""" """
Load an color or gray image from the file path. Load an color or gray image from the file path.
......
...@@ -149,6 +149,20 @@ def __get_used_layers__(output_layers, extra_layers=None): ...@@ -149,6 +149,20 @@ def __get_used_layers__(output_layers, extra_layers=None):
for layer in output_layers: for layer in output_layers:
dfs_travel(layer.full_name) dfs_travel(layer.full_name)
# A print layer needs special handling because no other layer depends
# on it. It is used to print the results of some layers when running
# the model for debugging purposes, so we explicitly add a print layer
# to the topology if all of its inputs are in the topology.
for layer in cp.g_config.model_config.layers:
if layer.type == 'print':
used = True
for inp in layer.inputs:
if inp.input_layer_name not in layer_names:
used = False
break
if used:
layer_names.add(layer.name)
return layer_names return layer_names
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn' 'ComposeNotAligned', 'firstn', 'xmap_readers'
] ]
import itertools import itertools
...@@ -224,3 +224,74 @@ def firstn(reader, n): ...@@ -224,3 +224,74 @@ def firstn(reader, n):
yield item yield item
return firstn_reader return firstn_reader
class XmapEndSignal():
    """Sentinel put on the queues of xmap_readers to mark the end of data."""
    pass


def xmap_readers(mapper, reader, process_num, buffer_size):
    """
    Use multiple threads to map samples from reader by a mapper defined by
    user. The samples travel through bounded queues, so this decorator also
    acts as a buffer.

    The order of the mapped samples is not guaranteed to match the order of
    the source reader, because several workers put results concurrently.

    A fresh set of queues and threads is created every time the returned
    reader is called, so the reader can be iterated more than once.

    NOTE(review): an exception raised by ``mapper`` or ``reader`` inside a
    worker thread ends that thread without an end signal, which leaves the
    consumer waiting -- callers should make sure both are exception free.

    :param mapper: a function to map sample.
    :type mapper: callable
    :param reader: the data reader to read from
    :type reader: callable
    :param process_num: number of worker threads mapping samples; must be
                        at least 1
    :type process_num: int
    :param buffer_size: max buffer size
    :type buffer_size: int
    :return: the decorated reader
    :rtype: callable
    :raises ValueError: if process_num is smaller than 1
    """
    if process_num < 1:
        # Without a worker nothing would ever reach the output queue.
        raise ValueError("process_num must be at least 1, got %r" %
                         (process_num, ))

    def xreader():
        end = XmapEndSignal()
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)

        # Feed the samples of the source reader into in_queue.
        def read_worker():
            for sample in reader():
                in_queue.put(sample)
            in_queue.put(end)

        # Map samples from in_queue and put the results into out_queue.
        def handle_worker():
            sample = in_queue.get()
            while not isinstance(sample, XmapEndSignal):
                out_queue.put(mapper(sample))
                sample = in_queue.get()
            # Hand the end signal on to the next worker, then report that
            # this worker is done.
            in_queue.put(end)
            out_queue.put(end)

        feeder = Thread(target=read_worker)
        feeder.daemon = True
        feeder.start()

        workers = [Thread(target=handle_worker) for _ in range(process_num)]
        for worker in workers:
            worker.daemon = True
        for worker in workers:
            worker.start()

        # Every worker emits exactly one end signal; the stream is complete
        # once all of them have been seen.
        finished = 0
        while finished < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finished += 1
            else:
                yield sample

    return xreader
...@@ -164,6 +164,7 @@ class OtherLayerTest(unittest.TestCase): ...@@ -164,6 +164,7 @@ class OtherLayerTest(unittest.TestCase):
maxid = layer.max_id(input=inference) maxid = layer.max_id(input=inference)
sampling_id = layer.sampling_id(input=inference) sampling_id = layer.sampling_id(input=inference)
eos = layer.eos(input=maxid, eos_id=5) eos = layer.eos(input=maxid, eos_id=5)
layer.printer(maxid)
print layer.parse_network([maxid, sampling_id, eos]) print layer.parse_network([maxid, sampling_id, eos])
def test_slicing_joining_layer(self): def test_slicing_joining_layer(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册