提交 9d72cab0 编写于 作者: D dangqingqing

Accelerating image processing for CNN

上级 f8ec510a
......@@ -195,3 +195,7 @@ if(WITH_DOC)
add_subdirectory(doc)
add_subdirectory(doc_cn)
endif()
# Add the plugin/opencv subdirectory to the build only when USE_OPENCV is set.
if(USE_OPENCV)
add_subdirectory(plugin/opencv)
endif()
# use opencv plugin
#
# Builds the DeJpeg shared library from the sources in this directory and
# links it against OpenCV and Boost.Python.
project(DeJpeg CXX C)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_SOURCE_DIR})
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)

set(DEJPEG_LINKER_LIBS "")

# opencv
find_package(OpenCV REQUIRED COMPONENTS core highgui imgproc)
include_directories(${OpenCV_INCLUDE_DIRS})
list(APPEND DEJPEG_LINKER_LIBS ${OpenCV_LIBS})
message(STATUS "OpenCV found (${OpenCV_CONFIG_PATH})")
add_definitions(-DUSE_OPENCV)

# boost-python
# Only the Boost installation pointed to by the BOOST_ROOT environment
# variable is considered; system locations are ignored.
set(Boost_NO_SYSTEM_PATHS ON)
if (Boost_NO_SYSTEM_PATHS)
  set(BOOST_ROOT $ENV{BOOST_ROOT})
  set(Boost_DIR ${BOOST_ROOT})
  set(Boost_INCLUDE_DIR "${BOOST_ROOT}/include")
  set(Boost_LIBRARIES "${BOOST_ROOT}/lib/")
endif (Boost_NO_SYSTEM_PATHS)
# The library is linked against Boost.Python unconditionally, so a missing
# Boost has to stop the configure step instead of failing later at link time.
find_package(Boost 1.46 REQUIRED COMPONENTS python)
include_directories(SYSTEM ${Boost_INCLUDE_DIR})
# The linker search path must be the Boost library directory, not the
# include directory.
link_directories(${Boost_LIBRARY_DIRS})
message(STATUS "Boost found (${Boost_INCLUDE_DIR})")
message(STATUS "Boost found (${Boost_LIBRARIES})")
list(APPEND DEJPEG_LINKER_LIBS ${Boost_LIBRARIES})

# Each globbing expression must be a full pattern; passing the directory and
# the wildcard as two separate arguments also matches the directory itself.
file(GLOB DEJPEG_HEADER "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
file(GLOB DEJPEG_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")

# NOTE(review): this replaces any flags inherited from the parent project
# rather than appending to them; kept as in the original - confirm intended.
set(CMAKE_CXX_FLAGS "-std=c++11 -O3 -fPIC -Wno-unused-parameter")

add_library(DeJpeg SHARED ${DEJPEG_SOURCES})
target_link_libraries(DeJpeg ${DEJPEG_LINKER_LIBS})
# Drop the "lib" prefix so the output file is named after the target
# (DeJpeg), matching the module name used by BOOST_PYTHON_MODULE.
set_target_properties(DeJpeg PROPERTIES PREFIX "")
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DataTransformer.h"
#include <time.h>
#include <limits>
/**
 * @brief Set up the transformer: copy the mean, allocate `capacity` reusable
 *        output buffers and create the worker thread pool.
 *
 * @param threadNum     Number of worker threads. Non-positive values fall
 *                      back to 12, the value that used to be hard-coded.
 * @param capacity      Number of prefetch buffers; also the capacity passed
 *                      to both blocking queues.
 * @param isTest        When true, transform() uses a centre crop and never
 *                      mirrors.
 * @param isColor       Decode as 3-channel colour (true) or 1-channel
 *                      grayscale (false).
 * @param cropHeight    Height of the crop written to the output.
 * @param cropWidth     Width of the crop written to the output.
 * @param imgSize       If positive, the image is resized before cropping
 *                      (see transform()).
 * @param isEltMean     Subtract an element-wise mean.
 * @param isChannelMean Subtract a per-channel mean.
 * @param meanValues    Mean data to copy, or NULL for no mean.
 */
DataTransformer::DataTransformer(int threadNum,
                                 int capacity,
                                 bool isTest,
                                 bool isColor,
                                 int cropHeight,
                                 int cropWidth,
                                 int imgSize,
                                 bool isEltMean,
                                 bool isChannelMean,
                                 float* meanValues)
    : isTest_(isTest),
      isColor_(isColor),
      cropHeight_(cropHeight),
      cropWidth_(cropWidth),
      imgSize_(imgSize),
      capacity_(capacity),
      prefetchFree_(capacity),
      prefetchFull_(capacity) {
  fetchCount_ = -1;
  scale_ = 1.0;
  isChannelMean_ = isChannelMean;
  isEltMean_ = isEltMean;

  // loadMean() assigns meanValues_ only when a mean is supplied, and the
  // destructor reads the pointer, so it needs a defined value first.
  meanValues_ = nullptr;
  loadMean(meanValues);

  imgPixels_ = cropHeight * cropWidth * (isColor_ ? 3 : 1);

  prefetch_.reserve(capacity);
  const size_t bufLen = static_cast<size_t>(imgPixels_) * 3;
  for (int i = 0; i < capacity; i++) {
    // The pair only stores a raw float*, so the shared_ptr gets a deleter
    // that releases the array together with the pair. `new float[n]()`
    // zero-initialises the whole buffer.
    DataTypePtr buf(new DataType(new float[bufLen](), 0), [](DataType* p) {
      delete[] p->first;
      delete p;
    });
    prefetch_.push_back(buf);
    prefetchFree_.enqueue(buf);
  }

  // Honour the caller's thread count instead of a fixed value.
  numThreads_ = threadNum > 0 ? threadNum : 12;
  syncThreadPool_.reset(new SyncThreadPool(numThreads_, false));
}
/**
 * @brief Copy the caller's mean values into a buffer owned by this object.
 *
 * The number of floats copied is the channel count for a per-channel mean,
 * otherwise cropHeight_ * cropWidth_ * channels.
 *
 * @param values Source mean data, or NULL when no mean is used; in that case
 *               meanValues_ is set to nullptr so it never holds an
 *               indeterminate value.
 */
void DataTransformer::loadMean(float* values) {
  if (!values) {
    meanValues_ = nullptr;
    return;
  }
  const int channels = isColor_ ? 3 : 1;
  const int count =
      isChannelMean_ ? channels : cropHeight_ * cropWidth_ * channels;
  meanValues_ = new float[count];
  memcpy(meanValues_, values, count * sizeof(float));
}
/**
 * @brief Decode an encoded image held in memory and transform it into trg.
 *
 * @param src  Pointer to the encoded image bytes.
 * @param size Number of bytes at src.
 * @param trg  Output buffer handed on to transform().
 *
 * If the input is empty or cannot be decoded, an error is logged and trg is
 * left untouched; transform() is not called on an empty image.
 */
void DataTransformer::startFetching(const char* src,
                                    const int size,
                                    float* trg) {
  if (!src || size <= 0) {
    LOG(ERROR) << "Could not decode image";
    return;
  }
  vector<char> imbuf(src, src + size);
  int cvFlag = (isColor_ ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE);
  cv::Mat im = cv::imdecode(cv::Mat(imbuf), cvFlag);
  if (!im.data) {
    LOG(ERROR) << "Could not decode image";
    LOG(ERROR) << im.channels() << " " << im.rows << " " << im.cols;
    // Nothing to resize or crop; bail out instead of processing an empty Mat.
    return;
  }
  this->transform(im, trg);
}
int DataTransformer::Rand(int min, int max) {
std::random_device source;
std::mt19937 rng(source());
std::uniform_int_distribution<int> dist(min, max);
return dist(rng);
}
/**
 * @brief Resize, crop, optionally mirror and mean-subtract one decoded image.
 *
 * Steps:
 *  1. If imgSize_ > 0 the image is resized so that its shorter side equals
 *     imgSize_ (the other side is scaled by the same ratio and is at least
 *     imgSize_).
 *  2. A cropHeight_ x cropWidth_ window is taken: at a random offset when
 *     not in test mode, centred otherwise.
 *  3. Pixels are written channel-major into target
 *     (index = (c * height + h) * width + w), flipped horizontally when
 *     mirroring was drawn, with the configured mean subtracted and the
 *     result multiplied by scale_.
 *
 * @param cvImgOri Decoded input image; read as interleaved 8-bit channels.
 * @param target   Output buffer; must hold at least
 *                 channels * cropHeight_ * cropWidth_ floats.
 */
void DataTransformer::transform(Mat& cvImgOri, float* target) {
  CHECK(!cvImgOri.empty()) << "Cannot transform an empty image";

  const int imgChannels = cvImgOri.channels();
  const int imgHeight = cvImgOri.rows;
  const int imgWidth = cvImgOri.cols;
  const bool doMirror = (!isTest_) && Rand(0, 1);

  // Dimensions of the image the crop is taken from.
  int th = imgHeight;
  int tw = imgWidth;
  cv::Mat img;
  if (imgSize_ > 0) {
    if (imgHeight > imgWidth) {
      tw = imgSize_;
      th = static_cast<int>(static_cast<double>(imgHeight) / imgWidth * tw);
      th = th > imgSize_ ? th : imgSize_;
    } else {
      th = imgSize_;
      tw = static_cast<int>(static_cast<double>(imgWidth) / imgHeight * th);
      tw = tw > imgSize_ ? tw : imgSize_;
    }
    cv::resize(cvImgOri, img, cv::Size(tw, th));
  } else {
    // Assign to the outer variable; a new local declared here would leave
    // the image used below empty.
    img = cvImgOri;
  }

  cv::Mat cv_cropped_img = img;
  if (cropHeight_ && cropWidth_) {
    // The crop window has to fit inside the (possibly resized) image.
    CHECK_GE(th, cropHeight_);
    CHECK_GE(tw, cropWidth_);
    int h_off = 0;
    int w_off = 0;
    if (!isTest_) {
      h_off = Rand(0, th - cropHeight_);
      w_off = Rand(0, tw - cropWidth_);
    } else {
      h_off = (th - cropHeight_) / 2;
      w_off = (tw - cropWidth_) / 2;
    }
    cv::Rect roi(w_off, h_off, cropWidth_, cropHeight_);
    cv_cropped_img = img(roi);
  } else {
    // No crop window: the loop below still iterates over
    // cropHeight_ x cropWidth_, so those must match the image being read.
    CHECK_EQ(cropHeight_, th);
    CHECK_EQ(cropWidth_, tw);
  }

  const int height = cropHeight_;
  const int width = cropWidth_;
  for (int h = 0; h < height; ++h) {
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      // Mirroring only changes the destination column.
      const int dstCol = doMirror ? (width - 1 - w) : w;
      for (int c = 0; c < imgChannels; ++c) {
        const int top_index = (c * height + h) * width + dstCol;
        const float pixel = static_cast<float>(ptr[img_index++]);
        if (isEltMean_) {
          // The element-wise mean holds cropHeight_ * cropWidth_ values per
          // channel (see loadMean()), so it is indexed with the crop size,
          // not the size of the original image.
          const int mean_index = (c * height + h) * width + w;
          target[top_index] = (pixel - meanValues_[mean_index]) * scale_;
        } else if (isChannelMean_) {
          target[top_index] = (pixel - meanValues_[c]) * scale_;
        } else {
          target[top_index] = pixel * scale_;
        }
      }
    }
  }  // target: BGR
}
/**
 * @brief Decode and transform a batch of encoded images on the thread pool.
 *
 * Worker `tid` handles items tid, tid + numThreads, ... Each item takes a
 * buffer from prefetchFree_, fills it, and puts it on prefetchFull_ for
 * obtain().
 *
 * @param data    Pointers to the encoded image bytes, one per image.
 * @param datalen Byte length of each entry of data.
 * @param labels  Label of each entry of data.
 */
void DataTransformer::start(vector<char*>& data, int* datalen, int* labels) {
  const int total = static_cast<int>(data.size());
  // Only capacity_ buffers exist. NOTE(review): exec() is assumed to block
  // until every worker has finished (the job captures locals by reference);
  // a larger batch would then leave workers waiting on prefetchFree_ with
  // nobody draining prefetchFull_.
  CHECK_LE(total, capacity_) << "Batch is larger than the prefetch capacity";

  auto job = [&](int tid, int numThreads) {
    for (int i = tid; i < total; i += numThreads) {
      DataTypePtr ret = prefetchFree_.dequeue();
      char* buf = data[i];
      int size = datalen[i];
      ret->second = labels[i];
      this->startFetching(buf, size, ret->first);
      prefetchFull_.enqueue(ret);
    }
  };
  syncThreadPool_->exec(job);
  fetchCount_ = total;
}
void DataTransformer::obtain(float* data, int* label) {
fetchCount_--;
if (fetchCount_ < 0) {
LOG(FATAL) << "Empty data";
}
DataTypePtr ret = prefetchFull_.dequeue();
*label = ret->second;
memcpy(data, ret->first, sizeof(float) * imgPixels_);
prefetchFree_.enqueue(ret);
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
// #define OPENCV_CAN_BREAK_BINARY_COMPATIBILITY
#include <opencv2/opencv.hpp>
#include <vector>
#include <string>
#include <algorithm>
#include "paddle/utils/Thread.h"
using namespace std;
using namespace cv;
using namespace paddle;
/**
 * This is an image processing module built on OpenCV. It supports
 * resizing, cropping, mirroring, scaling and subtracting the image mean.
 *
 * The class keeps two blocking queues (free and full) that share the same
 * set of pre-allocated buffers, so no memory is allocated per image. Buffers
 * move from the free queue to the full queue as worker threads fill them and
 * back again once obtain() has copied the result out.
 */
class DataTransformer {
public:
  DataTransformer(int threadNum,
                  int capacity,
                  bool isTest,
                  bool isColor,
                  int cropHeight,
                  int cropWidth,
                  int imgSize,
                  bool isEltMean,
                  bool isChannelMean,
                  float* meanValues);

  virtual ~DataTransformer() {
    // The mean buffer is allocated with new[] in loadMean(), so it has to be
    // released with delete[] (deleting a null pointer is a no-op).
    delete[] meanValues_;
  }

  /**
   * @brief Start multi-threads to transform a list of input data.
   * The processed data will be saved in Queue of prefetchFull_.
   *
   * @param data     Data containing the image strings to be transformed.
   * @param datalen  The length in bytes of each image string.
   * @param labels   The label of each input image.
   */
  void start(vector<char*>& data, int* datalen, int* labels);

  /**
   * @brief Applies the transformation on one image Mat.
   *
   * @param img    The input img to be transformed.
   * @param target target is used to save the transformed data.
   */
  void transform(Mat& img, float* target);

  /**
   * @brief Decode the image string, then call the transform() function.
   *
   * @param src  The input image string.
   * @param size The length of string.
   * @param trg  trg is used to save the transformed data.
   */
  void startFetching(const char* src, const int size, float* trg);

  /**
   * @brief Return the transformed data and its label.
   */
  void obtain(float* data, int* label);

private:
  int isTest_;
  int isColor_;
  int cropHeight_;
  int cropWidth_;
  int imgSize_;
  int capacity_;
  int fetchCount_;
  bool isEltMean_;
  bool isChannelMean_;
  int numThreads_;
  float scale_;
  int imgPixels_;
  // Owned copy of the mean values; stays null when no mean was supplied so
  // the destructor never sees an indeterminate pointer.
  float* meanValues_ = nullptr;

  /**
   * Initialize the mean values.
   */
  void loadMean(float* values);

  /**
   * @brief Generates a random integer from Uniform({min, min + 1, ..., max}).
   * @param min The lower bound (inclusive) value of the random number.
   * @param max The upper bound (inclusive) value of the random number.
   *
   * @return
   * A uniformly random integer value from ({min, min + 1, ..., max}).
   */
  int Rand(int min, int max);

  // One prefetch slot: (output buffer, label).
  typedef pair<float*, int> DataType;
  typedef std::shared_ptr<DataType> DataTypePtr;
  std::vector<DataTypePtr> prefetch_;
  std::unique_ptr<SyncThreadPool> syncThreadPool_;
  BlockingQueue<DataTypePtr> prefetchFree_;
  BlockingQueue<DataTypePtr> prefetchFull_;
};  // class DataTransformer
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <time.h>
#include <vector>
#include <sys/time.h>
#include <unistd.h>
#include <glog/logging.h>
#include <numpy/arrayobject.h>
#include <boost/python.hpp>
#include "DataTransformer.h"
using namespace boost::python;
using namespace std;
/**
* DecodeJpeg is an image processing API for interfacing Python and C++
* code DataTransformer, which used OpenCV and multi-threads to accelerate
* image processing.
* The Boost Python Library is used to wrap C++ interfaces.
*/
class DecodeJpeg {
public:
/**
* The constructor will create and nitialize an object of DataTransformer.
*/
DecodeJpeg(int threadNum,
int capacity,
bool isTest,
bool isColor,
int resize_min_size,
int cropSizeH,
int cropSizeW,
PyObject* meanValues) {
int channel = isColor ? 3 : 1;
bool isEltMean = false;
bool isChannelMean = false;
float* mean = NULL;
if (meanValues || meanValues != Py_None) {
if (!PyArray_Check(meanValues)) {
LOG(FATAL) << "Object is not a numpy array";
}
pyTypeCheck(meanValues);
int size = PyArray_SIZE(meanValues);
isChannelMean = (size == channel) ? true : false;
isEltMean = (size == channel * cropSizeH * cropSizeW) ? true : false;
CHECK(isChannelMean != isEltMean);
mean = (float*)PyArray_DATA(meanValues);
}
tfhandlerPtr_ = std::make_shared<DataTransformer>(threadNum,
capacity,
isTest,
isColor,
cropSizeH,
cropSizeW,
resize_min_size,
isEltMean,
isChannelMean,
mean);
}
~DecodeJpeg() {}
/**
* @brief This function is used to parse the Python object and convert
* the data to C++ format. Then it called the function of
* DataTransformer to start image processing.
* @param pysrc The input image list with string type.
* @param pylabel The input label of image.
* It's type is numpy.array with int32.
*/
void start(boost::python::list& pysrc, PyObject* pydlen, PyObject* pylabel) {
vector<char*> data;
int num = len(pysrc);
for (int t = 0; t < num; ++t) {
char* src = boost::python::extract<char*>(pysrc[t]);
data.push_back(src);
}
int* dlen = (int*)PyArray_DATA(pydlen);
int* dlabels = (int*)PyArray_DATA(pylabel);
tfhandlerPtr_->start(data, dlen, dlabels);
}
/**
* @brief Return one processed data.
* @param pytrg The processed image.
* @param pylabel The label of processed image.
*/
void get(PyObject* pytrg, PyObject* pylab) {
pyWritableCheck(pytrg);
pyWritableCheck(pylab);
pyContinuousCheck(pytrg);
pyContinuousCheck(pylab);
float* data = (float*)PyArray_DATA(pytrg);
int* label = (int*)PyArray_DATA(pylab);
tfhandlerPtr_->obtain(data, label);
}
/**
* @brief An object of DataTransformer, which is used to call
* the image processing funtions.
*/
std::shared_ptr<DataTransformer> tfhandlerPtr_;
private:
/**
* @brief Check whether the type of PyObject is valid or not.
*/
void pyTypeCheck(const PyObject* o) {
int typenum = PyArray_TYPE(o);
// clang-format off
int type =
typenum == NPY_UBYTE ? CV_8U :
typenum == NPY_BYTE ? CV_8S :
typenum == NPY_USHORT ? CV_16U :
typenum == NPY_SHORT ? CV_16S :
typenum == NPY_INT || typenum == NPY_LONG ? CV_32S :
typenum == NPY_FLOAT ? CV_32F :
typenum == NPY_DOUBLE ? CV_64F : -1;
// clang-format on
if (type < 0) {
LOG(FATAL) << "toMat: Data type = " << type << " is not supported";
}
}
/**
* @brief Check whether the PyObject is writable or not.
*/
void pyWritableCheck(PyObject* o) { CHECK(PyArray_ISWRITEABLE(o)); }
/**
* @brief Check whether the PyObject is c-contiguous or not.
*/
void pyContinuousCheck(PyObject* o) { CHECK(PyArray_IS_C_CONTIGUOUS(o)); }
};
/**
 * @brief Initialize the Python interpreter and numpy.
 *
 * The SIGINT handler that is installed after Py_Initialize() is read before
 * import_array() runs and is installed again afterwards, so the handler is
 * the same before and after the numpy import.
 */
static void initPython() {
Py_Initialize();
// Remember the current SIGINT handler.
PyOS_sighandler_t sighandler = PyOS_getsig(SIGINT);
import_array();
// Re-install the handler remembered above.
PyOS_setsig(SIGINT, sighandler);
}
/**
 * Use Boost.Python to expose the C++ interface to Python.
 *
 * Defines the extension module `DeJpeg`. On import it runs initPython() and
 * registers the class `DecodeJpeg` with an 8-argument constructor
 * (int, int, bool, bool, int, int, int, PyObject*) and the methods
 * `start` and `get`.
 */
BOOST_PYTHON_MODULE(DeJpeg) {
initPython();
class_<DecodeJpeg>("DecodeJpeg",
init<int, int, bool, bool, int, int, int, PyObject*>())
.def("start", &DecodeJpeg::start)
.def("get", &DecodeJpeg::get);
};
import os, psutil
import cv2
from paddle.utils.image_util import *
import multiprocessing
import subprocess, signal, sys
class CvImageTransfomer(ImageTransformer):
    """
    CvImageTransfomer uses python-opencv (cv2) to process images: resize so
    the short side equals min_size, crop to crop_size x crop_size, randomly
    mirror while training, then apply the ImageTransformer steps.
    """

    def __init__(self,
                 min_size=None,
                 crop_size=None,
                 transpose=None,
                 channel_swap=None,
                 mean=None,
                 is_train=True,
                 is_color=True):
        ImageTransformer.__init__(self, transpose, channel_swap, mean, is_color)
        self.min_size = min_size
        self.crop_size = crop_size
        self.is_train = is_train

    def cv_resize_fixed_short_side(self, im, min_size):
        """
        Resize im so that its shorter side equals min_size, scaling the
        longer side by the same factor (never below min_size).
        im: (H x W x K) or (H x W) ndarray
        """
        row, col = im.shape[:2]
        scale = min_size / float(min(row, col))
        if row < col:
            row = min_size
            col = int(round(col * scale))
            col = col if col > min_size else min_size
        else:
            col = min_size
            row = int(round(row * scale))
            row = row if row > min_size else min_size
        # cv2.resize expects dsize as (width, height), i.e. (cols, rows).
        resized_size = col, row
        im = cv2.resize(im, resized_size, interpolation=cv2.INTER_CUBIC)
        return im

    def crop_img(self, im):
        """
        Return cropped image.
        The size of the cropped image is crop_size * crop_size. The crop is
        taken at a random position and mirrored horizontally with
        probability 1/2 when is_train is set, and from the center otherwise.
        im: (H x W x K) ndarrays
        """
        row, col = im.shape[:2]
        start_h, start_w = 0, 0
        if self.is_train:
            start_h = np.random.randint(0, row - self.crop_size + 1)
            start_w = np.random.randint(0, col - self.crop_size + 1)
        else:
            # Floor division keeps the offsets usable as slice indices.
            start_h = (row - self.crop_size) // 2
            start_w = (col - self.crop_size) // 2
        end_h, end_w = start_h + self.crop_size, start_w + self.crop_size
        if self.is_color:
            im = im[start_h:end_h, start_w:end_w, :]
        else:
            im = im[start_h:end_h, start_w:end_w]
        if (self.is_train) and (np.random.randint(2) == 0):
            if self.is_color:
                im = im[:, ::-1, :]
            else:
                im = im[:, ::-1]
        return im

    def transform(self, im):
        """
        Resize, crop, convert to float32 and apply the ImageTransformer
        steps (transpose, swap channel, sub mean).
        """
        im = self.cv_resize_fixed_short_side(im, self.min_size)
        im = self.crop_img(im)
        # transpose, swap channel, sub mean
        im = im.astype('float32')
        # Use the value handed back by transformer(); a transpose or channel
        # swap cannot be applied to `im` in place.
        # NOTE(review): transformer() is assumed to return the processed
        # array - falls back to `im` if it returns None.
        out = ImageTransformer.transformer(self, im)
        return im if out is None else out

    def load_image_from_string(self, data):
        """
        Decode an encoded image string into an ndarray (color or grayscale
        depending on is_color).
        """
        flag = cv2.CV_LOAD_IMAGE_COLOR if self.is_color else cv2.CV_LOAD_IMAGE_GRAYSCALE
        im = cv2.imdecode(np.fromstring(data, np.uint8), flag)
        return im

    def transform_from_string(self, data):
        """
        Decode the image string, then run transform() on the result.
        """
        im = self.load_image_from_string(data)
        return self.transform(im)
class MultiProcessImageTransfomer():
    """
    Transforms a batch of image strings in several worker processes and
    hands the results back one by one through a multiprocessing queue.
    """

    def __init__(self,
                 procnum=10,
                 capacity=10240,
                 min_size=None,
                 crop_size=None,
                 transpose=None,
                 channel_swap=None,
                 mean=None,
                 is_train=True,
                 is_color=True):
        self.procnum = procnum
        self.capacity = capacity
        self.size = 0
        self.count = 0
        # Defined up front so __del__ and get() work before run_proc().
        self.procs = []
        signal.signal(signal.SIGTERM, self.kill_child_processes)
        self.fetch_queue = multiprocessing.Queue(maxsize=capacity)
        self.cv_transformer = CvImageTransfomer(min_size, crop_size, transpose,
                                                channel_swap, mean, is_train,
                                                is_color)

    def __del__(self):
        self._join_procs()

    def _join_procs(self):
        """
        Join every worker process, printing (not raising) any error.
        """
        try:
            for p in getattr(self, 'procs', []):
                p.join()
        except Exception as e:
            print(str(e))

    def reset(self, size):
        """
        Prepare for a new batch of `size` items.
        """
        self.size = size
        self.count = 0
        self.procs = []

    def run_proc(self, data, label):
        """
        Split the batch into procnum contiguous slices and start one daemon
        process per slice.
        """
        dlen = len(label)
        self.reset(dlen)
        for i in range(self.procnum):
            # Floor division: slice bounds must be integers.
            start = dlen * i // self.procnum
            end = dlen * (i + 1) // self.procnum
            proc = multiprocessing.Process(
                target=self.batch_transfomer,
                args=(data[start:end], label[start:end]))
            proc.daemon = True
            self.procs.append(proc)
        for p in self.procs:
            p.start()

    def get(self):
        """
        Return one processed image.
        """
        # block if necessary until an item is available
        data, lab = self.fetch_queue.get(block=True)
        self.count += 1
        # Once the whole batch has been fetched, reap the workers.
        if self.count == self.size:
            self._join_procs()
        return data, lab

    def batch_transfomer(self, data, label):
        """
        param data: input data in format of image string
        type data: a list of string
        label: the label of image
        """
        for i in range(len(label)):
            res = self.cv_transformer.transform_from_string(data[i])
            self.fetch_queue.put((res, int(label[i])))

    def kill_child_processes(self, signum, frame):
        """
        Kill a process's child processes in python.
        """
        parent_id = os.getpid()
        ps_command = subprocess.Popen(
            "ps -o pid --ppid %d --noheaders" % parent_id,
            shell=True,
            stdout=subprocess.PIPE)
        ps_output = ps_command.stdout.read()
        retcode = ps_command.wait()
        # NOTE(review): the last listed pid is skipped, presumably because it
        # belongs to the ps helper spawned above - confirm.
        for pid_str in ps_output.strip().split("\n")[:-1]:
            try:
                os.kill(int(pid_str), signal.SIGTERM)
            except OSError:
                # The child has already exited; keep terminating the rest.
                pass
        sys.exit()
......@@ -186,29 +186,32 @@ class ImageTransformer:
channel_swap=None,
mean=None,
is_color=True):
self.transpose = transpose
self.channel_swap = None
self.mean = None
self.is_color = is_color
self.set_transpose(transpose)
self.set_channel_swap(channel_swap)
self.set_mean(mean)
def set_transpose(self, order):
    """
    Store the axis order used for transposing.
    order: None to disable, otherwise an axis order; for color images it
           must have exactly 3 entries.
    """
    # Validate only when an order is given, so None is accepted.
    if order is not None:
        if self.is_color:
            assert 3 == len(order)
    self.transpose = order
def set_channel_swap(self, order):
    """
    Store the channel order used for swapping channels.
    order: None to disable, otherwise a channel order; for color images it
           must have exactly 3 entries.
    """
    # Validate only when an order is given, so None is accepted.
    if order is not None:
        if self.is_color:
            assert 3 == len(order)
    self.channel_swap = order
def set_mean(self, mean):
    """
    Store the mean to subtract.
    mean: None to disable; a 1-D array is reshaped to (N, 1, 1) so it acts
          as one value per channel; anything else is kept as an elementwise
          mean, which must be 3-dimensional for color images.
    """
    if mean is not None:
        # mean value, may be one value per channel
        if mean.ndim == 1:
            mean = mean[:, np.newaxis, np.newaxis]
        else:
            # elementwise mean
            if self.is_color:
                assert len(mean.shape) == 3
    self.mean = mean
def transformer(self, data):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册