提交 ae06debf 编写于 作者: D dangqingqing

Remove the C++ code and refine Python code.

上级 fe073d1f
......@@ -43,7 +43,6 @@ option(WITH_SWIG_PY "Compile PaddlePaddle with py PaddlePaddle prediction api" $
option(ON_TRAVIS "Running test on travis-ci or not." OFF)
option(ON_COVERALLS "Generating code coverage data on coveralls or not." OFF)
option(COVERALLS_UPLOAD "Uploading the generated coveralls json." ON)
option(USE_OPENCV "Compile PaddlePaddle with opencv" OFF)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
......@@ -196,7 +195,3 @@ if(WITH_DOC)
add_subdirectory(doc)
add_subdirectory(doc_cn)
endif()
if(USE_OPENCV)
add_subdirectory(plugin/opencv)
endif()
# use opencv plugin
project(DeJpeg CXX C)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_SOURCE_DIR})
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
set(DEJPEG_LINKER_LIBS "")
# opencv
find_package(OpenCV REQUIRED COMPONENTS core highgui imgproc)
include_directories(${OpenCV_INCLUDE_DIRS})
list(APPEND DEJPEG_LINKER_LIBS ${OpenCV_LIBS})
message(STATUS "OpenCV found (${OpenCV_CONFIG_PATH})")
add_definitions(-DUSE_OPENCV)
# boost-python
set(Boost_NO_SYSTEM_PATHS ON)
if (Boost_NO_SYSTEM_PATHS)
set(BOOST_ROOT $ENV{BOOST_ROOT})
set(Boost_DIR ${BOOST_ROOT})
set(Boost_INCLUDE_DIR "${BOOST_ROOT}/include")
set(Boost_LIBRARIES "${BOOST_ROOT}/lib/")
endif (Boost_NO_SYSTEM_PATHS)
find_package(Boost 1.46 COMPONENTS python)
include_directories(SYSTEM ${Boost_INCLUDE_DIR})
link_directories(${Boost_INCLUDE_DIR})
message(STATUS "Boost found (${Boost_INCLUDE_DIR})")
message(STATUS "Boost found (${Boost_LIBRARIES})")
list(APPEND DEJPEG_LINKER_LIBS ${Boost_LIBRARIES})
file(GLOB DEJPEG_HEADER "${CMAKE_CURRENT_SOURCE_DIR}" "*.h")
file(GLOB DEJPEG_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}" "*.cpp")
set(BUILD_PRIVATE_FLAGS
-Wno-all
-Wno-error
-Wno-non-virtual-dtor
-Wno-delete-non-virtual-dtor)
add_library(DeJpeg SHARED ${DEJPEG_SOURCES})
target_compile_options(DeJpeg BEFORE PRIVATE ${BUILD_PRIVATE_FLAGS})
target_link_libraries(DeJpeg ${DEJPEG_LINKER_LIBS})
set_target_properties(DeJpeg PROPERTIES PREFIX "")
add_style_check_target(DeJpeg ${DEJPEG_SOURCES})
add_style_check_target(DeJpeg ${DEJPEG_HEADER})
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DataTransformer.h"
#include <time.h>
#include <limits>
DataTransformer::DataTransformer(int threadNum,
int capacity,
bool isTest,
bool isColor,
int cropHeight,
int cropWidth,
int imgSize,
bool isEltMean,
bool isChannelMean,
float* meanValues)
: isTest_(isTest),
isColor_(isColor),
cropHeight_(cropHeight),
cropWidth_(cropWidth),
imgSize_(imgSize),
capacity_(capacity),
prefetchFree_(capacity),
prefetchFull_(capacity) {
fetchCount_ = -1;
scale_ = 1.0;
isChannelMean_ = isChannelMean;
isEltMean_ = isEltMean;
loadMean(meanValues);
imgPixels_ = cropHeight * cropWidth * (isColor_ ? 3 : 1);
prefetch_.reserve(capacity);
for (int i = 0; i < capacity; i++) {
auto d = std::make_shared<DataType>(new float[imgPixels_ * 3], 0);
prefetch_.push_back(d);
memset(prefetch_[i]->first, 0, imgPixels_ * sizeof(float));
prefetchFree_.enqueue(prefetch_[i]);
}
numThreads_ = threadNum;
syncThreadPool_.reset(new paddle::SyncThreadPool(numThreads_, false));
}
void DataTransformer::loadMean(float* values) {
if (values) {
int c = isColor_ ? 3 : 1;
int sz = isChannelMean_ ? c : cropHeight_ * cropWidth_ * c;
meanValues_ = new float[sz];
memcpy(meanValues_, values, sz * sizeof(float));
}
}
void DataTransformer::startFetching(const char* src,
const int size,
float* trg) {
std::vector<char> imbuf(src, src + size);
int cvFlag = (isColor_ ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE);
cv::Mat im = cv::imdecode(cv::Mat(imbuf), cvFlag);
if (!im.data) {
LOG(ERROR) << "Could not decode image";
LOG(ERROR) << im.channels() << " " << im.rows << " " << im.cols;
}
this->transform(im, trg);
}
int DataTransformer::Rand(int min, int max) {
std::random_device source;
std::mt19937 rng(source());
std::uniform_int_distribution<int> dist(min, max);
return dist(rng);
}
void DataTransformer::transform(cv::Mat& cvImgOri, float* target) {
const int imgChannels = cvImgOri.channels();
const int imgHeight = cvImgOri.rows;
const int imgWidth = cvImgOri.cols;
const bool doMirror = (!isTest_) && Rand(0, 1);
int h_off = 0;
int w_off = 0;
int th = imgHeight;
int tw = imgWidth;
cv::Mat img;
if (imgSize_ > 0) {
if (imgHeight > imgWidth) {
tw = imgSize_;
th = int(double(imgHeight) / imgWidth * tw);
th = th > imgSize_ ? th : imgSize_;
} else {
th = imgSize_;
tw = int(double(imgWidth) / imgHeight * th);
tw = tw > imgSize_ ? tw : imgSize_;
}
cv::resize(cvImgOri, img, cv::Size(tw, th));
} else {
cv::Mat img = cvImgOri;
}
cv::Mat cv_cropped_img = img;
if (cropHeight_ && cropWidth_) {
if (!isTest_) {
h_off = Rand(0, th - cropHeight_);
w_off = Rand(0, tw - cropWidth_);
} else {
h_off = (th - cropHeight_) / 2;
w_off = (tw - cropWidth_) / 2;
}
cv::Rect roi(w_off, h_off, cropWidth_, cropHeight_);
cv_cropped_img = img(roi);
} else {
CHECK_EQ(cropHeight_, imgHeight);
CHECK_EQ(cropWidth_, imgWidth);
}
int height = cropHeight_;
int width = cropWidth_;
int top_index;
for (int h = 0; h < height; ++h) {
const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
int img_index = 0;
for (int w = 0; w < width; ++w) {
for (int c = 0; c < imgChannels; ++c) {
if (doMirror) {
top_index = (c * height + h) * width + width - 1 - w;
} else {
top_index = (c * height + h) * width + w;
}
float pixel = static_cast<float>(ptr[img_index++]);
if (isEltMean_) {
int mean_index = (c * imgHeight + h) * imgWidth + w;
target[top_index] = (pixel - meanValues_[mean_index]) * scale_;
} else {
if (isChannelMean_) {
target[top_index] = (pixel - meanValues_[c]) * scale_;
} else {
target[top_index] = pixel * scale_;
}
}
}
}
} // target: BGR
}
void DataTransformer::start(std::vector<char*>& data,
int* datalen,
int* labels) {
auto job = [&](int tid, int numThreads) {
for (size_t i = tid; i < data.size(); i += numThreads) {
DataTypePtr ret = prefetchFree_.dequeue();
char* buf = data[i];
int size = datalen[i];
ret->second = labels[i];
this->startFetching(buf, size, ret->first);
prefetchFull_.enqueue(ret);
}
};
syncThreadPool_->exec(job);
fetchCount_ = data.size();
}
void DataTransformer::obtain(float* data, int* label) {
fetchCount_--;
if (fetchCount_ < 0) {
LOG(FATAL) << "Empty data";
}
DataTypePtr ret = prefetchFull_.dequeue();
*label = ret->second;
memcpy(data, ret->first, sizeof(float) * imgPixels_);
prefetchFree_.enqueue(ret);
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef DATATRANSFORMER_H_
#define DATATRANSFORMER_H_
#include <iostream>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <vector>
#include <string>
#include <algorithm>
#include "paddle/utils/Thread.h"
/**
* This is an image processing module with OpenCV, such as
* resizing, scaling, mirroring, substracting the image mean...
*
* This class has a double BlockQueue and they shared the same memory.
* It is used to avoid create memory each time. And it also can
* return the data even if the data are processing in multi-threads.
*/
class DataTransformer {
public:
DataTransformer(int threadNum,
int capacity,
bool isTest,
bool isColor,
int cropHeight,
int cropWidth,
int imgSize,
bool isEltMean,
bool isChannelMean,
float* meanValues);
virtual ~DataTransformer() {
if (meanValues_) {
free(meanValues_);
}
}
/**
* @brief Start multi-threads to transform a list of input data.
* The processed data will be saved in Queue of prefetchFull_.
*
* @param data Data containing the image string to be transformed.
* @param label The label of input image.
*/
void start(std::vector<char*>& data, int* datalen, int* labels);
/**
* @brief Applies the transformation on one image Mat.
*
* @param img The input img to be transformed.
* @param target target is used to save the transformed data.
*/
void transform(cv::Mat& img, float* target);
/**
* @brief Decode the image string, then calls transform() function.
*
* @param src The input image string.
* @param size The length of string.
* @param trg trg is used to save the transformed data.
*/
void startFetching(const char* src, const int size, float* trg);
/**
* @brief Return the transformed data and its label.
*/
void obtain(float* data, int* label);
private:
int isTest_;
int isColor_;
int cropHeight_;
int cropWidth_;
int imgSize_;
int capacity_;
int fetchCount_;
bool isEltMean_;
bool isChannelMean_;
int numThreads_;
float scale_;
int imgPixels_;
float* meanValues_;
/**
* Initialize the mean values.
*/
void loadMean(float* values);
/**
* @brief Generates a random integer from Uniform({min, min + 1, ..., max}).
* @param min The lower bound (inclusive) value of the random number.
* @param max The upper bound (inclusive) value of the random number.
*
* @return
* A uniformly random integer value from ({min, min + 1, ..., max}).
*/
int Rand(int min, int max);
typedef std::pair<float*, int> DataType;
typedef std::shared_ptr<DataType> DataTypePtr;
std::vector<DataTypePtr> prefetch_;
std::unique_ptr<paddle::SyncThreadPool> syncThreadPool_;
paddle::BlockingQueue<DataTypePtr> prefetchFree_;
paddle::BlockingQueue<DataTypePtr> prefetchFull_;
}; // class DataTransformer
#endif // DATATRANSFORMER_H_
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <time.h>
#include <vector>
#include <sys/time.h>
#include <unistd.h>
#include <glog/logging.h>
#include <numpy/arrayobject.h>
#include <boost/python.hpp>
#include "DataTransformer.h"
/**
* DecodeJpeg is an image processing API for interfacing Python and C++
* code DataTransformer, which used OpenCV and multi-threads to accelerate
* image processing.
* The Boost Python Library is used to wrap C++ interfaces.
*/
class DecodeJpeg {
public:
/**
* The constructor will create and initialize an object of DataTransformer.
*/
DecodeJpeg(int threadNum,
int capacity,
bool isTest,
bool isColor,
int resize_min_size,
int cropSizeH,
int cropSizeW,
PyObject* meanValues) {
int channel = isColor ? 3 : 1;
bool isEltMean = false;
bool isChannelMean = false;
float* mean = NULL;
if (meanValues || meanValues != Py_None) {
if (!PyArray_Check(meanValues)) {
LOG(FATAL) << "Object is not a numpy array";
}
pyTypeCheck(meanValues);
int size = PyArray_SIZE(reinterpret_cast<PyArrayObject*>(meanValues));
isChannelMean = (size == channel) ? true : false;
isEltMean = (size == channel * cropSizeH * cropSizeW) ? true : false;
CHECK(isChannelMean != isEltMean);
mean = (float*)PyArray_DATA(reinterpret_cast<PyArrayObject*>(meanValues));
}
tfhandlerPtr_ = std::make_shared<DataTransformer>(threadNum,
capacity,
isTest,
isColor,
cropSizeH,
cropSizeW,
resize_min_size,
isEltMean,
isChannelMean,
mean);
}
~DecodeJpeg() {}
/**
* @brief This function is used to parse the Python object and convert
* the data to C++ format. Then it called the function of
* DataTransformer to start image processing.
* @param pysrc The input image list with string type.
* @param pylabel The input label of image.
* It's type is numpy.array with int32.
*/
void start(boost::python::list& pysrc, PyObject* pydlen, PyObject* pylabel) {
std::vector<char*> data;
int num = len(pysrc);
for (int t = 0; t < num; ++t) {
char* src = boost::python::extract<char*>(pysrc[t]);
data.push_back(src);
}
int* dlen = (int*)PyArray_DATA(reinterpret_cast<PyArrayObject*>(pydlen));
int* dlabels =
(int*)PyArray_DATA(reinterpret_cast<PyArrayObject*>(pylabel));
tfhandlerPtr_->start(data, dlen, dlabels);
}
/**
* @brief Return one processed data.
* @param pytrg The processed image.
* @param pylabel The label of processed image.
*/
void get(PyObject* pytrg, PyObject* pylab) {
pyWritableCheck(pytrg);
pyWritableCheck(pylab);
pyContinuousCheck(pytrg);
pyContinuousCheck(pylab);
float* data = (float*)PyArray_DATA(reinterpret_cast<PyArrayObject*>(pytrg));
int* label = (int*)PyArray_DATA(reinterpret_cast<PyArrayObject*>(pylab));
tfhandlerPtr_->obtain(data, label);
}
/**
* @brief An object of DataTransformer, which is used to call
* the image processing funtions.
*/
std::shared_ptr<DataTransformer> tfhandlerPtr_;
private:
/**
* @brief Check whether the type of PyObject is valid or not.
*/
void pyTypeCheck(PyObject* o) {
int typenum = PyArray_TYPE(reinterpret_cast<PyArrayObject*>(o));
// clang-format off
int type =
typenum == NPY_UBYTE ? CV_8U :
typenum == NPY_BYTE ? CV_8S :
typenum == NPY_USHORT ? CV_16U :
typenum == NPY_SHORT ? CV_16S :
typenum == NPY_INT || typenum == NPY_LONG ? CV_32S :
typenum == NPY_FLOAT ? CV_32F :
typenum == NPY_DOUBLE ? CV_64F : -1;
// clang-format on
if (type < 0) {
LOG(FATAL) << "toMat: Data type = " << type << " is not supported";
}
}
/**
* @brief Check whether the PyObject is writable or not.
*/
void pyWritableCheck(PyObject* o) {
CHECK(PyArray_ISWRITEABLE(reinterpret_cast<PyArrayObject*>(o)));
}
/**
* @brief Check whether the PyObject is c-contiguous or not.
*/
void pyContinuousCheck(PyObject* o) {
CHECK(PyArray_IS_C_CONTIGUOUS(reinterpret_cast<PyArrayObject*>(o)));
}
}; // DecodeJpeg
/**
* @brief Initialize the Python interpreter and numpy.
*/
static void initPython() {
Py_Initialize();
PyOS_sighandler_t sighandler = PyOS_getsig(SIGINT);
import_array();
PyOS_setsig(SIGINT, sighandler);
}
/**
* Use Boost.Python to expose C++ interface to Python.
*/
BOOST_PYTHON_MODULE(DeJpeg) {
initPython();
boost::python::class_<DecodeJpeg>(
"DecodeJpeg",
boost::python::init<int, int, bool, bool, int, int, int, PyObject*>())
.def("start", &DecodeJpeg::start)
.def("get", &DecodeJpeg::get);
};
import os, psutil
import cv2
from paddle.utils.image_util import *
import os, sys
import numpy as np
from PIL import Image
from cStringIO import StringIO
import multiprocessing
import subprocess, signal, sys
from functools import partial
from paddle.utils.image_util import *
from paddle.trainer.config_parser import logger
try:
import cv2
except ImportError:
logger.warning("OpenCV2 is not installed, using PIL to prcoess")
cv2 = None
class CvImageTransfomer(ImageTransformer):
class CvTransfomer(ImageTransformer):
"""
CvImageTransfomer used python-opencv to process image.
CvTransfomer used python-opencv to process image.
"""
def __init__(self,
def __init__(
self,
min_size=None,
crop_size=None,
transpose=None,
transpose=(2, 0, 1), # transpose to C * H * W
channel_swap=None,
mean=None,
is_train=True,
......@@ -23,22 +34,17 @@ class CvImageTransfomer(ImageTransformer):
self.crop_size = crop_size
self.is_train = is_train
def cv_resize_fixed_short_side(self, im, min_size):
def resize(self, im, min_size):
row, col = im.shape[:2]
scale = min_size / float(min(row, col))
if row < col:
row = min_size
col = int(round(col * scale))
col = col if col > min_size else min_size
new_row, new_col = min_size, min_size
if row > col:
new_row = min_size * row / col
else:
col = min_size
row = int(round(row * scale))
row = row if row > min_size else min_size
resized_size = row, col
im = cv2.resize(im, resized_size, interpolation=cv2.INTER_CUBIC)
new_col = min_size * col / row
im = cv2.resize(im, (new_row, new_col), interpolation=cv2.INTER_CUBIC)
return im
def crop_img(self, im):
def crop_and_flip(self, im):
"""
Return cropped image.
The size of the cropped image is inner_size * inner_size.
......@@ -65,8 +71,8 @@ class CvImageTransfomer(ImageTransformer):
return im
def transform(self, im):
im = self.cv_resize_fixed_short_side(im, self.min_size)
im = self.crop_img(im)
im = self.resize(im, self.min_size)
im = self.crop_and_flip(im)
# transpose, swap channel, sub mean
im = im.astype('float32')
ImageTransformer.transformer(self, im)
......@@ -81,90 +87,187 @@ class CvImageTransfomer(ImageTransformer):
im = self.load_image_from_string(data)
return self.transform(im)
def load_image_from_file(self, file):
flag = cv2.CV_LOAD_IMAGE_COLOR if self.is_color else cv2.CV_LOAD_IMAGE_GRAYSCALE
im = cv2.imread(file, flag)
return im
class MultiProcessImageTransfomer():
def __init__(self,
procnum=10,
capacity=10240,
def transform_from_file(self, file):
im = self.load_image_from_file(file)
return self.transform(im)
class PILTransfomer(ImageTransformer):
"""
PILTransfomer used PIL to process image.
"""
def __init__(
self,
min_size=None,
crop_size=None,
transpose=None,
transpose=(2, 0, 1), # transpose to C * H * W
channel_swap=None,
mean=None,
is_train=True,
is_color=True):
self.procnum = procnum
self.capacity = capacity
self.size = 0
self.count = 0
signal.signal(signal.SIGTERM, self.kill_child_processes)
self.fetch_queue = multiprocessing.Queue(maxsize=capacity)
self.cv_transformer = CvImageTransfomer(min_size, crop_size, transpose,
channel_swap, mean, is_train,
is_color)
ImageTransformer.__init__(self, transpose, channel_swap, mean, is_color)
self.min_size = min_size
self.crop_size = crop_size
self.is_train = is_train
def __del__(self):
try:
for p in self.procs:
p.join()
except Exception as e:
print str(e)
def reset(self, size):
self.size = size
self.count = 0
self.procs = []
def run_proc(self, data, label):
dlen = len(label)
self.reset(dlen)
for i in xrange(self.procnum):
start = dlen * i / self.procnum
end = dlen * (i + 1) / self.procnum
proc = multiprocessing.Process(
target=self.batch_transfomer,
args=(data[start:end], label[start:end]))
proc.daemon = True
self.procs.append(proc)
for p in self.procs:
p.start()
def get(self):
"""
Return one processed image.
"""
# block if necessary until an item is available
data, lab = self.fetch_queue.get(block=True)
self.count += 1
if self.count == self.size:
try:
for p in self.procs:
p.join()
except Exception as e:
print str(e)
return data, lab
def resize(self, im, min_size):
row, col = im.size[:2]
new_row, new_col = min_size, min_size
if row > col:
new_row = min_size * row / col
else:
new_col = min_size * col / row
im = im.resize((new_row, new_col), Image.ANTIALIAS)
return im
def batch_transfomer(self, data, label):
def crop_and_flip(self, im):
"""
param data: input data in format of image string
type data: a list of string
label: the label of image
Return cropped image.
The size of the cropped image is inner_size * inner_size.
"""
for i in xrange(len(label)):
res = self.cv_transformer.transform_from_string(data[i])
self.fetch_queue.put((res, int(label[i])))
row, col = im.size[:2]
start_h, start_w = 0, 0
if self.is_train:
start_h = np.random.randint(0, row - self.crop_size + 1)
start_w = np.random.randint(0, col - self.crop_size + 1)
else:
start_h = (row - self.crop_size) / 2
start_w = (col - self.crop_size) / 2
end_h, end_w = start_h + self.crop_size, start_w + self.crop_size
im = im.crop((start_h, start_w, end_h, end_w))
if (self.is_train) and (np.random.randint(2) == 0):
im = im.transpose(Image.FLIP_LEFT_RIGHT)
return im
def transform(self, im):
im = self.resize(im, self.min_size)
im = self.crop_and_flip(im)
im = np.array(im, dtype=np.float32) # convert to numpy.array
# transpose, swap channel, sub mean
ImageTransformer.transformer(self, im)
return im
def load_image_from_string(self, data):
im = Image.open(StringIO(data))
return im
def kill_child_processes(self, signum, frame):
def transform_from_string(self, data):
im = self.load_image_from_string(data)
return self.transform(im)
def load_image_from_file(self, file):
im = Image.open(file)
return im
def transform_from_file(self, file):
im = self.load_image_from_file(file)
return self.transform(im)
def warpper(cls, (dat, label)):
return cls.job(dat, label)
class MultiProcessImageTransformer(object):
def __init__(self,
procnum=10,
resize_size=None,
crop_size=None,
transpose=(2, 0, 1),
channel_swap=None,
mean=None,
is_train=True,
is_color=True,
is_img_string=True):
"""
Kill a process's child processes in python.
Processing image with multi-process. If it is used in PyDataProvider,
the simple usage for CNN is as follows:
.. code-block:: python
def hool(settings, is_train, **kwargs):
settings.is_train = is_train
settings.mean_value = np.array([103.939,116.779,123.68], dtype=np.float32)
settings.input_types = [
dense_vector(3 * 224 * 224),
integer_value(1)]
settings.transformer = MultiProcessImageTransformer(
procnum=10,
resize_size=256,
crop_size=224,
transpose=(2, 0, 1),
mean=settings.mean_values,
is_train=settings.is_train)
@provider(init_hook=hook, pool_size=20480)
def process(settings, file_list):
with open(file_list, 'r') as fdata:
for line in fdata:
data_dic = np.load(line.strip()) # load the data batch pickled by Pickle.
data = data_dic['data']
labels = data_dic['label']
labels = np.array(labels, dtype=np.float32)
for im, lab in settings.dp.run(data, labels):
yield [im.astype('float32'), int(lab)]
:param procnum: processor number.
:type procnum: int
:param resize_size: the shorter edge size of image after resizing.
:type resize_size: int
:param crop_size: the croping size.
:type crop_size: int
:param transpose: the transpose order, Paddle only allow C * H * W order.
:type transpose: tuple or list
:param channel_swap: the channel swap order, RGB or BRG.
:type channel_swap: tuple or list
:param mean: the mean values of image, per-channel mean or element-wise mean.
:type mean: array, The dimension is 1 for per-channel mean.
The dimension is 3 for element-wise mean.
:param is_train: training peroid or testing peroid.
:type is_train: bool.
:param is_color: the image is color or gray.
:type is_color: bool.
:param is_img_string: The input can be the file name of image or image string.
:type is_img_string: bool.
"""
parent_id = os.getpid()
ps_command = subprocess.Popen(
"ps -o pid --ppid %d --noheaders" % parent_id,
shell=True,
stdout=subprocess.PIPE)
ps_output = ps_command.stdout.read()
retcode = ps_command.wait()
for pid_str in ps_output.strip().split("\n")[:-1]:
os.kill(int(pid_str), signal.SIGTERM)
sys.exit()
self.pool = multiprocessing.Pool(procnum)
self.is_img_string = is_img_string
if cv2 is not None:
self.transformer = CvTransfomer(resize_size, crop_size, transpose,
channel_swap, mean, is_train,
is_color)
else:
self.transformer = PILTransfomer(resize_size, crop_size, transpose,
channel_swap, mean, is_train,
is_color)
def run(self, data, label):
try:
fun = partial(warpper, self)
return self.pool.imap_unordered(fun, zip(data, label), chunksize=5)
except KeyboardInterrupt:
self.pool.terminate()
except Exception, e:
self.pool.terminate()
def job(self, data, label):
if self.is_img_string:
return self.transformer.transform_from_string(data), label
else:
return self.transformer.transform_from_file(data), label
def __getstate__(self):
self_dict = self.__dict__.copy()
del self_dict['pool']
return self_dict
def __setstate__(self, state):
self.__dict__.update(state)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册