From 5929ac80d7525b7c52aca95cc1deea1741d92c76 Mon Sep 17 00:00:00 2001 From: BohaoWu <37072443+BohaoWu@users.noreply.github.com> Date: Mon, 28 Sep 2020 21:00:15 +0800 Subject: [PATCH] Init Zelda --- core/preprocess/hwvideoframe/CMakeLists.txt | 32 +++ core/preprocess/hwvideoframe/README.md | 14 + .../hwvideoframe/cuda/CMakeLists.txt | 22 ++ core/preprocess/hwvideoframe/cuda/resize.cu | 251 ++++++++++++++++++ .../hwvideoframe/include/center_crop.h | 30 +++ core/preprocess/hwvideoframe/include/div.h | 29 ++ .../hwvideoframe/include/image_io.h | 33 +++ .../hwvideoframe/include/normalize.h | 32 +++ .../hwvideoframe/include/op_context.h | 76 ++++++ core/preprocess/hwvideoframe/include/resize.h | 54 ++++ .../hwvideoframe/include/resize_by_factor.h | 48 ++++ .../hwvideoframe/include/rgb_swap.h | 33 +++ core/preprocess/hwvideoframe/include/sub.h | 31 +++ core/preprocess/hwvideoframe/include/utils.h | 27 ++ .../pybind/pybind_gpu_preprocess.cpp | 58 ++++ .../hwvideoframe/src/center_crop.cpp | 34 +++ core/preprocess/hwvideoframe/src/div.cpp | 26 ++ core/preprocess/hwvideoframe/src/image_io.cpp | 42 +++ .../preprocess/hwvideoframe/src/normalize.cpp | 37 +++ core/preprocess/hwvideoframe/src/resize.cpp | 43 +++ .../hwvideoframe/src/resize_by_factor.cpp | 79 ++++++ core/preprocess/hwvideoframe/src/rgb_swap.cpp | 22 ++ core/preprocess/hwvideoframe/src/sub.cpp | 37 +++ core/preprocess/hwvideoframe/src/utlis.cpp | 23 ++ .../reader/test_audio_reader.py | 20 ++ .../reader/test_frame_reader.py | 42 +++ .../reader/test_functional.py | 169 ++++++++++++ python/setup.py.app.in | 35 ++- 28 files changed, 1372 insertions(+), 7 deletions(-) create mode 100644 core/preprocess/hwvideoframe/CMakeLists.txt create mode 100644 core/preprocess/hwvideoframe/README.md create mode 100644 core/preprocess/hwvideoframe/cuda/CMakeLists.txt create mode 100644 core/preprocess/hwvideoframe/cuda/resize.cu create mode 100644 core/preprocess/hwvideoframe/include/center_crop.h create mode 100644 
core/preprocess/hwvideoframe/include/div.h
 create mode 100644 core/preprocess/hwvideoframe/include/image_io.h
 create mode 100644 core/preprocess/hwvideoframe/include/normalize.h
 create mode 100644 core/preprocess/hwvideoframe/include/op_context.h
 create mode 100644 core/preprocess/hwvideoframe/include/resize.h
 create mode 100644 core/preprocess/hwvideoframe/include/resize_by_factor.h
 create mode 100644 core/preprocess/hwvideoframe/include/rgb_swap.h
 create mode 100644 core/preprocess/hwvideoframe/include/sub.h
 create mode 100644 core/preprocess/hwvideoframe/include/utils.h
 create mode 100644 core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/center_crop.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/div.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/image_io.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/normalize.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/resize.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/resize_by_factor.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/rgb_swap.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/sub.cpp
 create mode 100644 core/preprocess/hwvideoframe/src/utlis.cpp
 create mode 100644 python/paddle_serving_app/reader/test_audio_reader.py
 create mode 100644 python/paddle_serving_app/reader/test_frame_reader.py
 create mode 100644 python/paddle_serving_app/reader/test_functional.py

diff --git a/core/preprocess/hwvideoframe/CMakeLists.txt b/core/preprocess/hwvideoframe/CMakeLists.txt
new file mode 100644
index 00000000..768fdbda
--- /dev/null
+++ b/core/preprocess/hwvideoframe/CMakeLists.txt
@@ -0,0 +1,32 @@
+cmake_minimum_required(VERSION 3.2)
+project(hw-frame-extract)
+# SET(CUDA_VERSION 10.1)
+
+#gcc version
+#GCC('gcc482')
+#CUDA("10.1")
+set(global_cflags_str "-g -pipe -W -Wall -fPIC")
+set(CMAKE_C_FLAGS ${global_cflags_str})
+
+#C++ flags.
+set(global_cxxflags_str "-g -pipe -W -Wall -fPIC -std=c++11")
+set(CMAKE_CXX_FLAGS ${global_cxxflags_str})
+
+add_subdirectory(cuda)
+add_subdirectory(pybind11)
+# "gpu" is the kernel library built in cuda/; the remaining entries are the
+# CUDA runtime and static NPP libraries the operators in src/ call into.
+set (EXTRA_LIBS ${EXTRA_LIBS} gpu cudart nppidei_static nppial_static npps_static nppc_static culibos)
+
+message(STATUS "hwvideoframe source dir: ${CMAKE_CURRENT_SOURCE_DIR}")
+include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include")
+include_directories("${CMAKE_CURRENT_SOURCE_DIR}/pybind11/include")
+include_directories("/opt/compiler/cuda-10.1/include")
+# NOTE(review): machine-specific Python include path; should come from find_package(PythonLibs).
+include_directories("/home/work/wubohao/baidu/third-party/python/include/python2.7")
+
+# The binding source is committed under pybind/, so it must be part of the glob.
+file(GLOB SOURCE_FILES src/*.cpp pybind/*.cpp pybind11/*.cpp )
+
+# link_directories() accepts directories only; the libraries themselves are
+# listed in EXTRA_LIBS and handed to target_link_libraries() below.
+link_directories("/opt/compiler/cuda-10.1/lib64")
+# link_directories("/home/work/wubohao/baidu/third-party/python/lib")
+
+#.so
+add_library(gpupreprocess SHARED ${SOURCE_FILES})
+target_link_libraries (gpupreprocess ${EXTRA_LIBS})
diff --git a/core/preprocess/hwvideoframe/README.md b/core/preprocess/hwvideoframe/README.md
new file mode 100644
index 00000000..ab3b27d3
--- /dev/null
+++ b/core/preprocess/hwvideoframe/README.md
@@ -0,0 +1,14 @@
+# hwvideoframe
+简要说明
+
+## 快速开始
+如何构建、安装、运行
+
+## 测试
+如何执行自动化测试
+
+## 如何贡献
+贡献patch流程、质量要求
+
+## 讨论
+百度Hi讨论群:XXXX
diff --git a/core/preprocess/hwvideoframe/cuda/CMakeLists.txt b/core/preprocess/hwvideoframe/cuda/CMakeLists.txt
new file mode 100644
index 00000000..4efef2ad
--- /dev/null
+++ b/core/preprocess/hwvideoframe/cuda/CMakeLists.txt
@@ -0,0 +1,22 @@
+cmake_minimum_required(VERSION 3.2)
+project(gpu)
+FIND_PACKAGE(CUDA ${CUDA_VERSION} REQUIRED)
+SET(CUDA_TARGET_INCLUDE ${CUDA_TOOLKIT_ROOT_DIR}-${CUDA_VERSION}/targets/${CMAKE_HOST_SYSTEM_PROCESSOR}-${LOWER_SYSTEM_NAME}/include)
+
+file(GLOB_RECURSE CURRENT_HEADERS *.h *.hpp *.cuh)
+file(GLOB CURRENT_SOURCES *.cpp *.cu)
+file(GLOB CUDA_LIBS /opt/compiler/cuda-10.1/lib64/*.so)
+
+source_group("Include" FILES ${CURRENT_HEADERS})
+source_group("Source" FILES ${CURRENT_SOURCES})
+
+set(CMAKE_CUDA_FLAGS "-ccbin /opt/compiler/gcc-4.8.2/bin
-Xcompiler -fPIC --std=c++11")
+set(CUDA_NVCC_FLAGS "-L/opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
+
+include_directories("/opt/compiler/cuda-10.1/include")
+
+#cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES})
+cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES})
+target_link_libraries(gpu ${CUDA_LIBS})
+
+# import libgpupreprocess as pp
\ No newline at end of file
diff --git a/core/preprocess/hwvideoframe/cuda/resize.cu b/core/preprocess/hwvideoframe/cuda/resize.cu
new file mode 100644
index 00000000..8475611b
--- /dev/null
+++ b/core/preprocess/hwvideoframe/cuda/resize.cu
@@ -0,0 +1,251 @@
+#include "cuda_runtime.h"
+
+// Fixed-point interpolation weights are scaled by 2^INTER_RESIZE_COEF_BITS.
+const int INTER_RESIZE_COEF_BITS = 11;
+const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS;
+
+// The only channel layout the kernels implement: 3 interleaved channels.
+const int RESIZE_CHANNELS = 3;
+
+// Clamp x into [a, b - 1].
+// A function instead of the former macro: arguments are evaluated once and
+// there is no operator-precedence or stray-semicolon hazard at the call site.
+__device__ __forceinline__ int clip(int x, int a, int b) {
+    return x >= a ? (x < b ? x : b - 1) : a;
+}
+
+// The two source samples that contribute to one output coordinate on one axis.
+struct AxisTaps {
+    int lo;      // index of the first source sample
+    int hi;      // index of the second source sample, always a valid index
+    float frac;  // weight of `hi`; the weight of `lo` is (1 - frac)
+    bool edge;   // clamped at the far edge: only `lo` contributes
+};
+
+// Map output column dx to its source columns (pixel centres aligned).
+__device__ __forceinline__ AxisTaps horizontalTaps(int dx, int inputWidth, int outputWidth) {
+    const double scale_x = (double) inputWidth / outputWidth;
+    float fx = (float)((dx + 0.5) * scale_x - 0.5);
+    const int sx = floorf(fx);
+    fx = fx - sx;
+
+    AxisTaps taps;
+    taps.lo = sx;
+    taps.frac = fx;
+    taps.edge = false;
+    if (taps.lo < 0) {
+        taps.frac = 0.f;
+        taps.lo = 0;
+    }
+    if (taps.lo >= (inputWidth - 1)) {
+        taps.frac = 0.f;
+        taps.lo = inputWidth - 1;
+        taps.edge = true;
+    }
+    // Clamp the right-hand neighbour: lo + 1 would index one pixel past the
+    // row (and past the whole buffer on the last row) when lo is the last column.
+    taps.hi = ::min(taps.lo + 1, inputWidth - 1);
+    return taps;
+}
+
+// Map output row dy to its source rows (pixel centres aligned).
+__device__ __forceinline__ AxisTaps verticalTaps(int dy, int inputHeight, int outputHeight) {
+    const double scale_y = (double) inputHeight / outputHeight;
+    const float fy = (float)((dy + 0.5) * scale_y - 0.5);
+    const int sy = floorf(fy);
+
+    AxisTaps taps;
+    taps.frac = fy - sy;
+    taps.lo = clip(sy + 0, 0, inputHeight);
+    taps.hi = clip(sy + 1, 0, inputHeight);
+    taps.edge = false;
+    return taps;
+}
+
+// Bilinear resize in floating point; one thread per output pixel.
+// Only inputChannels == 3 is implemented; other values leave output untouched.
+__global__ void resizeCudaKernel( const float* input,
+                                  float* output,
+                                  const int inputWidth,
+                                  const int inputHeight,
+                                  const int outputWidth,
+                                  const int outputHeight,
+                                  const int inputChannels)
+{
+    //2D Index of current thread
+    const int dx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int dy = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if ((dx >= outputWidth) || (dy >= outputHeight)) {
+        return;
+    }
+    // TODO: support grayscale (1 channel) and an alpha channel
+    if (inputChannels != RESIZE_CHANNELS) {
+        return;
+    }
+
+    const AxisTaps tx = horizontalTaps(dx, inputWidth, outputWidth);
+    const AxisTaps ty = verticalTaps(dy, inputHeight, outputHeight);
+
+    const float wx_lo = 1.f - tx.frac;
+    const float wx_hi = tx.frac;
+    const float wy_lo = 1.f - ty.frac;
+    const float wy_hi = ty.frac;
+
+    const float* s11 = input + (ty.lo * inputWidth + tx.lo) * inputChannels;
+    const float* s12 = input + (ty.lo * inputWidth + tx.hi) * inputChannels;
+    const float* s21 = input + (ty.hi * inputWidth + tx.lo) * inputChannels;
+    const float* s22 = input + (ty.hi * inputWidth + tx.hi) * inputChannels;
+    float* dst = output + (dy * outputWidth + dx) * RESIZE_CHANNELS;
+
+    for (int c = 0; c < RESIZE_CHANNELS; ++c) {
+        float h_rst00, h_rst01;
+        if (tx.edge) {
+            // Past the last source column: take the edge pixel as-is.
+            h_rst00 = s11[c];
+            h_rst01 = s21[c];
+        } else {
+            h_rst00 = s11[c] * wx_lo + s12[c] * wx_hi;
+            h_rst01 = s21[c] * wx_lo + s22[c] * wx_hi;
+        }
+        dst[c] = h_rst00 * wy_lo + h_rst01 * wy_hi;
+    }
+}
+
+// Bilinear resize using integer weights scaled by INTER_RESIZE_COEF_SCALE.
+// Samples are truncated to int on load and the result is narrowed to
+// unsigned char before being stored back as float.
+// Only inputChannels == 3 is implemented; other values leave output untouched.
+__global__ void resizeCudaKernel_fixpt( const float* input,
+                                        float* output,
+                                        const int inputWidth,
+                                        const int inputHeight,
+                                        const int outputWidth,
+                                        const int outputHeight,
+                                        const int inputChannels)
+{
+    //2D Index of current thread
+    const int dx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int dy = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if ((dx >= outputWidth) || (dy >= outputHeight)) {
+        return;
+    }
+    // TODO: support grayscale (1 channel) and an alpha channel
+    if (inputChannels != RESIZE_CHANNELS) {
+        return;
+    }
+
+    const AxisTaps tx = horizontalTaps(dx, inputWidth, outputWidth);
+    const AxisTaps ty = verticalTaps(dy, inputHeight, outputHeight);
+
+    short2 cbufx;
+    cbufx.x = lrintf((1.f - tx.frac) * INTER_RESIZE_COEF_SCALE);
+    cbufx.y = lrintf(tx.frac * INTER_RESIZE_COEF_SCALE);
+
+    short2 cbufy;
+    cbufy.x = lrintf((1.f - ty.frac) * INTER_RESIZE_COEF_SCALE);
+    cbufy.y = lrintf(ty.frac * INTER_RESIZE_COEF_SCALE);
+
+    const float* s11 = input + (ty.lo * inputWidth + tx.lo) * inputChannels;
+    const float* s12 = input + (ty.lo * inputWidth + tx.hi) * inputChannels;
+    const float* s21 = input + (ty.hi * inputWidth + tx.lo) * inputChannels;
+    const float* s22 = input + (ty.hi * inputWidth + tx.hi) * inputChannels;
+    float* dst = output + (dy * outputWidth + dx) * RESIZE_CHANNELS;
+
+    for (int c = 0; c < RESIZE_CHANNELS; ++c) {
+        const int v11 = static_cast<int>(s11[c]);
+        const int v12 = static_cast<int>(s12[c]);
+        const int v21 = static_cast<int>(s21[c]);
+        const int v22 = static_cast<int>(s22[c]);
+
+        int h_rst00, h_rst01;
+        if (tx.edge) {
+            // Past the last source column: the edge pixel gets the full weight.
+            h_rst00 = v11 * INTER_RESIZE_COEF_SCALE;
+            h_rst01 = v21 * INTER_RESIZE_COEF_SCALE;
+        } else {
+            h_rst00 = v11 * cbufx.x + v12 * cbufx.y;
+            h_rst01 = v21 * cbufx.x + v22 * cbufx.y;
+        }
+        // Drop 4 + 16 + 2 = 22 bits (two 11-bit weights), with rounding on the last step.
+        const unsigned char value = (unsigned char)(( ((cbufy.x * (h_rst00 >> 4)) >> 16) + ((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >> 2);
+        dst[c] = value;
+    }
+}
+
+// Resize an interleaved 3-channel float image that lives in device memory.
+//
+// input / output   device buffers of inputWidth*inputHeight*3 and
+//                  outputWidth*outputHeight*3 floats respectively
+// use_fixed_point  select the integer-weight kernel instead of the float one
+//
+// Returns cudaErrorInvalidValue for null buffers, non-positive sizes or a
+// channel count other than 3; otherwise the launch or execution status.
+extern "C" cudaError_t resize_linear( const float* input,
+                                      float* output,
+                                      const int inputWidth,
+                                      const int inputHeight,
+                                      const int outputWidth,
+                                      const int outputHeight,
+                                      const int inputChannels,
+                                      const bool use_fixed_point) {
+    // Reject inputs the kernels cannot handle instead of reporting success
+    // with an untouched output buffer (or launching an empty grid).
+    if (input == nullptr || output == nullptr ||
+        inputWidth <= 0 || inputHeight <= 0 ||
+        outputWidth <= 0 || outputHeight <= 0 ||
+        inputChannels != RESIZE_CHANNELS) {
+        return cudaErrorInvalidValue;
+    }
+
+    //Specify a reasonable block size
+    const dim3 block(16,16);
+
+    //Calculate grid size to cover the whole image
+    const dim3 grid((outputWidth + block.x - 1) / block.x, (outputHeight + block.y - 1) / block.y);
+
+    //Launch the size conversion kernel
+    if (use_fixed_point) {
+        resizeCudaKernel_fixpt<<<grid, block>>>(input, output, inputWidth, inputHeight, outputWidth, outputHeight, inputChannels);
+    } else {
+        resizeCudaKernel<<<grid, block>>>(input, output, inputWidth, inputHeight, outputWidth, outputHeight, inputChannels);
+    }
+
+    // Launch-configuration errors are reported through cudaGetLastError().
+    const cudaError_t launch_status = cudaGetLastError();
+    if (launch_status != cudaSuccess) {
+        return launch_status;
+    }
+
+    //Synchronize to surface any error raised while the kernel was running
+    return cudaDeviceSynchronize();
+}
\ No newline at end of file
diff --git a/core/preprocess/hwvideoframe/include/center_crop.h b/core/preprocess/hwvideoframe/include/center_crop.h
new file mode 100644
index 00000000..8a9d75ab
--- /dev/null
+++ b/core/preprocess/hwvideoframe/include/center_crop.h
@@ -0,0 +1,30 @@
+/*******************************************
+ *
+ * Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
+ *
+ ******************************************/
+/**
+ * @file center_crop.h
+ * @author yinyijun@baidu.com
+ * @date 2020-06-15
+ **/
+
+#ifndef BAIDU_CVPRE_CENTER_CROP_H
+#define BAIDU_CVPRE_CENTER_CROP_H
+
+#include <memory>
+#include <npp.h>
+#include "op_context.h"
+
+// Crops a size x size window from the center of the given image.
+// A dimension of the input that is smaller than size is kept unchanged.
+class CenterCrop {
+public:
+    CenterCrop(int size) : _size(size) {}
+    // Returns a newly allocated image; the input is not modified.
+    std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
+
+private:
+    int _size; // edge length of the square crop window, in pixels
+};
+
+#endif // BAIDU_CVPRE_CENTER_CROP_H
\ No newline at end of file
diff --git a/core/preprocess/hwvideoframe/include/div.h b/core/preprocess/hwvideoframe/include/div.h
new file mode 100644
index 00000000..ad5fe2de
--- /dev/null
+++ b/core/preprocess/hwvideoframe/include/div.h
@@ -0,0 +1,29 @@
+/*******************************************
+ *
+ * Copyright (c) 2020 Baidu.com, Inc.
All Rights Reserved + * + ******************************************/ +/** + * @file div.h + * @author yinyijun@baidu.com + * @date 2020-06-11 + **/ + +#ifndef BAIDU_CVPRE_DIV_H +#define BAIDU_CVPRE_DIV_H + +#include +#include +#include "op_context.h" + +// divide by some float number for all pixel +class Div { +public: + Div(float value); + std::shared_ptr operator()(std::shared_ptr input); + +private: + Npp32f _divisor; +}; + +#endif // BAIDU_CVPRE_DIV_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/image_io.h b/core/preprocess/hwvideoframe/include/image_io.h new file mode 100644 index 00000000..3f4b3e07 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/image_io.h @@ -0,0 +1,33 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved + * + ******************************************/ +/** + * @file image_io.h + * @author yinyijun@baidu.com + * @date 2020-06-08 + **/ + +#ifndef BAIDU_CVPRE_IMAGE_IO_H +#define BAIDU_CVPRE_IMAGE_IO_H + +#include +#include +#include + +#include "op_context.h" + +// Input operator that copy numpy data to gpu buffer +class Image2Gpubuffer { +public: + std::shared_ptr operator()(pybind11::array_t array); +}; + +// Output operator that copy gpu buffer data to numpy +class Gpubuffer2Image { +public: + pybind11::array_t operator()(std::shared_ptr input); +}; + +#endif // BAIDU_CVPRE_IMAGE_IO_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/normalize.h b/core/preprocess/hwvideoframe/include/normalize.h new file mode 100644 index 00000000..54e810a8 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/normalize.h @@ -0,0 +1,32 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file normalize.h + * @author yinyijun@baidu.com + * @date 2020-06-04 + **/ + +#ifndef BAIDU_CVPRE_NORMALIZE_H +#define BAIDU_CVPRE_NORMALIZE_H + +#include +#include +#include +#include "op_context.h" + +// utilize normalize operator on gpu +class Normalize { +public: + Normalize(const std::vector &mean, const std::vector &std, bool channel_first = false); + std::shared_ptr operator()(std::shared_ptr input); + +private: + Npp32f _mean[CHANNEL_SIZE]; + Npp32f _std[CHANNEL_SIZE]; + bool _channel_first; // indicate whether the channel is dimension 0, unsupported +}; + +#endif // BAIDU_CVPRE_NORMALIZE_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/op_context.h b/core/preprocess/hwvideoframe/include/op_context.h new file mode 100644 index 00000000..7f278e44 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/op_context.h @@ -0,0 +1,76 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file op_context.h + * @author yinyijun@baidu.com + * @date 2020-06-03 + **/ + +#ifndef BAIDU_CVPRE_OPCONTEXT_H +#define BAIDU_CVPRE_OPCONTEXT_H + +#include + +const size_t CHANNEL_SIZE = 3; + +// The context as input/ouput of all operators +// contains pointer to raw data on gpu, frame size +class OpContext { +public: + OpContext() { + _step = 0; + _size = 0; + _p_frame = nullptr; + } + // constructor to apply gpu memory of image raw data + OpContext(int height, int width) { + _step = sizeof(Npp32f) * width * CHANNEL_SIZE; + _length = height * width * CHANNEL_SIZE; + _size = _step * height; + _nppi_size.height = height; + _nppi_size.width = width; + cudaMalloc((void **)(&_p_frame), _size); + } + virtual ~OpContext() { + free_memory(); + } + +public: + Npp32f* p_frame() const { + return _p_frame; + } + int step() const { + return _step; + } + int length() const { + return _length; + } + int size() const { + return _size; + } + NppiSize& nppi_size() { + return _nppi_size; + } + void free_memory() { + if (_p_frame != nullptr) { + cudaFree(_p_frame); + _p_frame = nullptr; + } + _nppi_size.height = 0; + _nppi_size.width = 0; + _step = 0; + _size = 0; + } + +private: + Npp32f *_p_frame; // pointer to raw data on gpu + int _step; // number of bytes in a row + int _length; // length of _p_frame, _size = _length * sizeof(Npp32f) + int _size; // number of bytes of the image + NppiSize _nppi_size; // contains height and width +}; + +#endif // BAIDU_CVPRE_OPCONTEXT_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/resize.h b/core/preprocess/hwvideoframe/include/resize.h new file mode 100644 index 00000000..64bea4f2 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/resize.h @@ -0,0 +1,54 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file resize.h + * @author yinyijun@baidu.com + * @date 2020-06-15 + **/ + +#ifndef BAIDU_CVPRE_RESIZE_H +#define BAIDU_CVPRE_RESIZE_H + +#include +#include + +#include +#include "op_context.h" + +extern "C" cudaError_t resize_linear(const float* input, + float* output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels, + const bool use_fixed_point); + +// Resize the input numpy array Image to the given size. +// only support linear interpolation +// only support RGB channels +class Resize { +public: + // size is an int, smaller edge of the image will be matched to this number. + Resize(int size, int max_size = 214748364, bool use_fixed_point = false, int interpolation = 0) + : _size(size), _max_size(max_size), _use_fixed_point(use_fixed_point), _interpolation(interpolation) {}; + // size is a sequence like (w, h), output size will be matched to this + Resize(std::vector size, int max_size = 214748364, bool use_fixed_point = false, int interpolation = 0) + : _size(-1), _max_size(max_size), _use_fixed_point(use_fixed_point), _interpolation(interpolation) { + _target_size[0] = size[0]; + _target_size[1] = size[1]; + } + std::shared_ptr operator()(std::shared_ptr input); + +private: + int _size; // target of smaller edge + int _target_size[2]; // target size sequence (w, h) + int _max_size; + bool _use_fixed_point; + int _interpolation; // unused +}; + +#endif // BAIDU_CVPRE_RESIZE_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/resize_by_factor.h b/core/preprocess/hwvideoframe/include/resize_by_factor.h new file mode 100644 index 00000000..36f1625e --- /dev/null +++ b/core/preprocess/hwvideoframe/include/resize_by_factor.h @@ -0,0 +1,48 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file resize_by_factor.h + * @author wubohao@baidu.com + * @date 2020-07-21 + **/ + +#ifndef BAIDU_CVPRE_RESIZE_BY_FACTOR_H +#define BAIDU_CVPRE_RESIZE_BY_FACTOR_H + +#include +#include + +#include +#include "op_context.h" + +extern "C" cudaError_t resize_linear(const float *input, + float *output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels, + const bool use_fixed_point); + +// Resize the input numpy array Image to a size multiple of factor which is usually required by a network +// only support linear interpolation +// only support RGB channels +class ResizeByFactor +{ +public: + // Resize factor. make width and height multiple factor of the value of factor. Default is 32 + ResizeByFactor(int factor = 32, int max_side_len = 2400, bool use_fixed_point = false, int interpolation = 0) + : _factor(factor), _max_side_len(max_side_len), _use_fixed_point(use_fixed_point), _interpolation(interpolation){}; + std::shared_ptr operator()(std::shared_ptr input); + +private: + int _factor; // target of smaller edge + int _max_side_len; + bool _use_fixed_point; + int _interpolation; // unused +}; + +#endif // BAIDU_CVPRE_RESIZE_BY_FACTOR_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/rgb_swap.h b/core/preprocess/hwvideoframe/include/rgb_swap.h new file mode 100644 index 00000000..037c0163 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/rgb_swap.h @@ -0,0 +1,33 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file rgb_swap.h + * @author yinyijun@baidu.com + * @date 2020-06-01 + **/ + +#ifndef BAIDU_CVPRE_RGB_SWAP_H +#define BAIDU_CVPRE_RGB_SWAP_H + +#include +#include +#include "op_context.h" + +// swap channel 0 and channel 2 for every pixel +// both RGB2BGR and BGR2RGB use this operator +class SwapChannel { +public: + SwapChannel(){}; + std::shared_ptr operator()(std::shared_ptr input); + +private: + static const int _ORDER[CHANNEL_SIZE]; // describing how channel values are permutated +}; + +class RGB2BGR : public SwapChannel {}; +class BGR2RGB : public SwapChannel {}; + +#endif // BAIDU_CVPRE_RGB_SWAP_H diff --git a/core/preprocess/hwvideoframe/include/sub.h b/core/preprocess/hwvideoframe/include/sub.h new file mode 100644 index 00000000..54202f4f --- /dev/null +++ b/core/preprocess/hwvideoframe/include/sub.h @@ -0,0 +1,31 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved + * + ******************************************/ +/** + * @file sub.h + * @author yinyijun@baidu.com + * @date 2020-07-06 + **/ + +#ifndef BAIDU_CVPRE_SUB_H +#define BAIDU_CVPRE_SUB_H + +#include +#include +#include +#include "op_context.h" + +// subtract by some float numbers +class Sub { +public: + Sub(float subtractor); + Sub(const std::vector &subtractors); + std::shared_ptr operator()(std::shared_ptr input); + +private: + Npp32f _subtractors[CHANNEL_SIZE]; +}; + +#endif // BAIDU_CVPRE_SUB_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/include/utils.h b/core/preprocess/hwvideoframe/include/utils.h new file mode 100644 index 00000000..85157cd4 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/utils.h @@ -0,0 +1,27 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file rgb_swap.h + * @author yinyijun@baidu.com + * @date 2020-06-10 + **/ + +#ifndef BAIDU_CVPRE_UTILS_H +#define BAIDU_CVPRE_UTILS_H + +#include + +#include + +// verify return value of npp function +// throw an exception if failed +void verify_npp_ret(const std::string& function_name, NppStatus ret); + +// verify return value of cuda runtime function +// throw an exception if failed +void verify_cuda_ret(const std::string& function_name, cudaError_t ret); + +#endif // BAIDU_CVPRE_UTILS_H \ No newline at end of file diff --git a/core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp b/core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp new file mode 100644 index 00000000..3d0989f3 --- /dev/null +++ b/core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp @@ -0,0 +1,58 @@ +#include +#include + +#include "div.h" +#include "sub.h" +#include "image_io.h" +#include "rgb_swap.h" +#include "normalize.h" +#include "center_crop.h" +#include "resize.h" +#include "resize_by_factor.h" + +PYBIND11_MODULE(libgpupreprocess, m) { + pybind11::class_>(m, "OpContext"); + pybind11::class_(m, "Image2Gpubuffer") + .def(pybind11::init<>()) + .def("__call__", &Image2Gpubuffer::operator()); + pybind11::class_(m, "Gpubuffer2Image") + .def(pybind11::init<>()) + .def("__call__", &Gpubuffer2Image::operator()); + pybind11::class_(m, "RGB2BGR") + .def(pybind11::init<>()) + .def("__call__", &RGB2BGR::operator()); + pybind11::class_(m, "BGR2RGB") + .def(pybind11::init<>()) + .def("__call__", &BGR2RGB::operator()); + pybind11::class_
(m, "Div") + .def(pybind11::init()) + .def("__call__", &Div::operator()); + pybind11::class_(m, "Sub") + .def(pybind11::init()) + .def(pybind11::init&>()) + .def("__call__", &Sub::operator()); + pybind11::class_(m, "Normalize") + .def(pybind11::init&, const std::vector&, bool>(), + pybind11::arg("mean"), pybind11::arg("std"), pybind11::arg("channel_first") = false) + .def("__call__", &Normalize::operator()); + pybind11::class_(m, "CenterCrop") + .def(pybind11::init()) + .def("__call__", &CenterCrop::operator()); + pybind11::class_(m, "Resize") + .def(pybind11::init(), + pybind11::arg("size"), + pybind11::arg("max_size") = 214748364, + pybind11::arg("use_fixed_point") = false) + .def(pybind11::init&, int, bool>(), + pybind11::arg("target_size"), + pybind11::arg("max_size") = 214748364, + pybind11::arg("use_fixed_point") = false) + .def("__call__", &Resize::operator()); + pybind11::class_(m, "ResizeByFactor") + .def(pybind11::init(), + pybind11::arg("factor") = 32, + pybind11::arg("max_side_len") = 2400, + pybind11::arg("use_fixed_point") = false) + .def("__call__", &ResizeByFactor::operator()); +} + diff --git a/core/preprocess/hwvideoframe/src/center_crop.cpp b/core/preprocess/hwvideoframe/src/center_crop.cpp new file mode 100644 index 00000000..cb882500 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/center_crop.cpp @@ -0,0 +1,34 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file center_crop.cpp + * @author yinyijun@baidu.com + * @date 2020-06-15 + **/ +#include "center_crop.h" + +#include + +#include +#include "utils.h" + +std::shared_ptr CenterCrop::operator()(std::shared_ptr input) { + int new_width = std::min(_size, input->nppi_size().width); + int new_height = std::min(_size, input->nppi_size().height); + auto output = std::make_shared(new_height, new_width); + int x_start = (input->nppi_size().width - new_width) / 2; + int y_start = (input->nppi_size().height - new_height) / 2; + Npp32f* p_src = input->p_frame() + + y_start * input->nppi_size().width * CHANNEL_SIZE + + x_start * CHANNEL_SIZE; + NppStatus ret = nppiCopy_32f_C3R(p_src, + input->step(), + output->p_frame(), + output->step(), + output->nppi_size()); + verify_npp_ret("nppiCopy_32f_C3R", ret); + return output; +} diff --git a/core/preprocess/hwvideoframe/src/div.cpp b/core/preprocess/hwvideoframe/src/div.cpp new file mode 100644 index 00000000..7d8b5fa2 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/div.cpp @@ -0,0 +1,26 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved + * + ******************************************/ +/** + * @file div.cpp + * @author yinyijun@baidu.com + * @date 2020-06-11 + **/ +#include "div.h" + +#include + +#include +#include "utils.h" + +Div::Div(float value) { + _divisor = value; +} + +std::shared_ptr Div::operator()(std::shared_ptr input) { + NppStatus ret = nppsDivC_32f_I(_divisor, input->p_frame(), input->length()); + verify_npp_ret("nppsDivC_32f_I", ret); + return input; +} diff --git a/core/preprocess/hwvideoframe/src/image_io.cpp b/core/preprocess/hwvideoframe/src/image_io.cpp new file mode 100644 index 00000000..196b21c1 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/image_io.cpp @@ -0,0 +1,42 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved + * + ******************************************/ +/** + * @file image_io.cpp + * @author yinyijun@baidu.com + * @date 2020-06-08 + **/ +#include "image_io.h" + +#include + +#include +#include +#include "utils.h" + +std::shared_ptr Image2Gpubuffer::operator()(pybind11::array_t input) { + pybind11::buffer_info buf = input.request(); + if (buf.format != pybind11::format_descriptor::format()) { + throw std::runtime_error("Incompatible format: expected a float numpy!"); + } + if (buf.ndim != 3) { + throw std::runtime_error("Number of dimensions must be three"); + } + if (buf.shape[2] != CHANNEL_SIZE) { + throw std::runtime_error("Number of channels must be three"); + } + auto result = std::make_shared(buf.shape[0], buf.shape[1]); + auto ret = cudaMemcpy(result->p_frame(), static_cast(buf.ptr), result->size(), cudaMemcpyHostToDevice); + verify_cuda_ret("cudaMemcpy", ret); + return result; +} + +pybind11::array_t Gpubuffer2Image::operator()(std::shared_ptr input) { + auto result = pybind11::array_t({ input->nppi_size().height, input->nppi_size().width, (int)CHANNEL_SIZE }); + pybind11::buffer_info buf = result.request(); + auto ret = cudaMemcpy(static_cast(buf.ptr), input->p_frame(), input->size(), cudaMemcpyDeviceToHost); + verify_cuda_ret("cudaMemcpy", ret); + return result; +} diff --git a/core/preprocess/hwvideoframe/src/normalize.cpp b/core/preprocess/hwvideoframe/src/normalize.cpp new file mode 100644 index 00000000..374716e5 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/normalize.cpp @@ -0,0 +1,37 @@ +/******************************************* + * + * Copyright (c) 2020 Baidu.com, Inc. 
All Rights Reserved
+ *
+ ******************************************/
+/**
+ * @file normalize.cpp
+ * @author yinyijun@baidu.com
+ * @date 2020-06-04
+ **/
+#include "normalize.h"
+#include
+
+#include
+#include "utils.h"
+
+// Stores the per-channel mean and standard deviation used by operator().
+// Both vectors must hold exactly CHANNEL_SIZE values, otherwise
+// std::runtime_error is thrown.
+// `channel_first` is only recorded here; operator() below does not read it.
+// NOTE(review): zero entries in `std` are not rejected.
+Normalize::Normalize(const std::vector &mean, const std::vector &std, bool channel_first) {
+    if (mean.size() != CHANNEL_SIZE) {
+        throw std::runtime_error("size of mean must be three");
+    }
+    if (std.size() != CHANNEL_SIZE) {
+        throw std::runtime_error("size of std must be three");
+    }
+    for (size_t i = 0; i < CHANNEL_SIZE; i++) {
+        _mean[i] = mean[i];
+        _std[i] = std[i];
+    }
+    _channel_first = channel_first;
+}
+
+// Normalizes the frame in place: subtracts _mean, then divides by _std,
+// channel by channel. Returns the same `input` pointer.
+// Throws std::runtime_error (via verify_npp_ret) when either NPP call fails.
+std::shared_ptr Normalize::operator()(std::shared_ptr input) {
+    NppStatus ret = nppiSubC_32f_C3IR(_mean, input->p_frame(), input->step(), input->nppi_size());
+    verify_npp_ret("nppiSubC_32f_C3IR", ret);
+    ret = nppiDivC_32f_C3IR(_std, input->p_frame(), input->step(), input->nppi_size());
+    verify_npp_ret("nppiDivC_32f_C3IR", ret);
+    return input;
+}
diff --git a/core/preprocess/hwvideoframe/src/resize.cpp b/core/preprocess/hwvideoframe/src/resize.cpp
new file mode 100644
index 00000000..8f823ceb
--- /dev/null
+++ b/core/preprocess/hwvideoframe/src/resize.cpp
@@ -0,0 +1,43 @@
+/*******************************************
+ *
+ * Copyright (c) 2020 Baidu.com, Inc.
All Rights Reserved
+ *
+ ******************************************/
+/**
+ * @file resize.cpp
+ * @author yinyijun@baidu.com
+ * @date 2020-06-15
+ **/
+#include "resize.h"
+
+#include
+#include
+
+#include "utils.h"
+
+// Resizes `input` into a newly allocated buffer with resize_linear and
+// returns that buffer; `input` is left untouched.
+// Two modes, selected by _size:
+//   * _size == -1: explicit target, width = _target_size[0] and
+//     height = _target_size[1], each capped at _max_size;
+//   * otherwise: uniform scaling so the shorter side becomes _size, reduced
+//     further if the longer side would then exceed _max_size.
+// Throws std::runtime_error (via verify_cuda_ret) when resize_linear fails.
+std::shared_ptr Resize::operator()(std::shared_ptr input) {
+    int resized_width = 0, resized_height = 0;
+    if (_size == -1) {
+        resized_width = std::min(_target_size[0], _max_size);
+        resized_height = std::min(_target_size[1], _max_size);
+    } else {
+        int im_max_size = std::max(input->nppi_size().height, input->nppi_size().width);
+        // Scale factor that maps the shorter side onto _size.
+        float percent = (float)_size / std::min(input->nppi_size().height, input->nppi_size().width);
+        // If that would push the longer side past _max_size, scale by the
+        // longer side instead.
+        if (round(percent * im_max_size) > _max_size) {
+            percent = float(_max_size) / float(im_max_size);
+        }
+        resized_width = int(round(input->nppi_size().width * percent));
+        resized_height = int(round(input->nppi_size().height * percent));
+    }
+    auto output = std::make_shared(resized_height, resized_width);
+    auto ret = resize_linear(input->p_frame(),
+                             output->p_frame(),
+                             input->nppi_size().width,
+                             input->nppi_size().height,
+                             output->nppi_size().width,
+                             output->nppi_size().height,
+                             CHANNEL_SIZE,
+                             _use_fixed_point);
+    verify_cuda_ret("resize_linear", ret);
+    return output;
+}
diff --git a/core/preprocess/hwvideoframe/src/resize_by_factor.cpp b/core/preprocess/hwvideoframe/src/resize_by_factor.cpp
new file mode 100644
index 00000000..76e91a73
--- /dev/null
+++ b/core/preprocess/hwvideoframe/src/resize_by_factor.cpp
@@ -0,0 +1,79 @@
+/*******************************************
+ *
+ * Copyright (c) 2020 Baidu.com, Inc.
All Rights Reserved
+ *
+ ******************************************/
+/**
+ * @file resize_by_factor.cpp
+ * @author wubohao@baidu.com
+ * @date 2020-07-21
+ **/
+#include "resize_by_factor.h"
+
+#include
+#include
+
+#include "resize.h"
+#include "utils.h"
+
+// Snaps one side length onto a multiple of `factor`:
+//   * already a multiple           -> unchanged;
+//   * at most one whole multiple   -> exactly `factor`;
+//   * otherwise                    -> one multiple below the floored value.
+// `factor` must be non-zero (checked by the caller).
+static int snap_side_to_factor(int side, int factor)
+{
+    if (side % factor == 0)
+    {
+        return side;
+    }
+    const int whole_multiples = side / factor;
+    if (whole_multiples <= 1)
+    {
+        return factor;
+    }
+    return (whole_multiples - 1) * factor;
+}
+
+// Resizes `input` so that its longer side does not exceed _max_side_len and
+// both sides are multiples of _factor, writing into a newly allocated buffer.
+// Returns the new buffer, or a null pointer when no valid size can be
+// produced (_factor is zero, or a computed side is not positive).
+// Throws std::runtime_error (via verify_cuda_ret) when resize_linear fails.
+std::shared_ptr ResizeByFactor::operator()(std::shared_ptr input)
+{
+    // `% 0` and `/ 0` are undefined; report it the same way as any other
+    // unusable size.
+    if (_factor == 0)
+    {
+        return nullptr;
+    }
+    int resized_width = input->nppi_size().width;
+    int resized_height = input->nppi_size().height;
+    const int longest_side = std::max(resized_width, resized_height);
+    float ratio = 1;
+    if (longest_side > _max_side_len)
+    {
+        // Convert before dividing: dividing two integers first truncates the
+        // ratio to 0 for every image that needs shrinking.
+        ratio = float(_max_side_len) / float(longest_side);
+    }
+    // Both sides are snapped with _factor; the earlier code mixed in a
+    // hard-coded 32, which only agreed with _factor when _factor was 32.
+    resized_width = snap_side_to_factor(int(resized_width * ratio), _factor);
+    resized_height = snap_side_to_factor(int(resized_height * ratio), _factor);
+    if (resized_width <= 0 || resized_height <= 0)
+    {
+        return nullptr;
+    }
+    auto output = std::make_shared(resized_height, resized_width);
+    auto ret = resize_linear(input->p_frame(),
+                             output->p_frame(),
+                             input->nppi_size().width,
+                             input->nppi_size().height,
+                             output->nppi_size().width,
+                             output->nppi_size().height,
+                             CHANNEL_SIZE,
+                             _use_fixed_point);
+    verify_cuda_ret("resize_linear", ret);
+    return output;
+}
diff --git a/core/preprocess/hwvideoframe/src/rgb_swap.cpp b/core/preprocess/hwvideoframe/src/rgb_swap.cpp
new file mode 100644
index 00000000..559372f0
--- /dev/null
+++ b/core/preprocess/hwvideoframe/src/rgb_swap.cpp
@@ -0,0 +1,22 @@
+/*******************************************
+ *
+ * Copyright (c)
2020 Baidu.com, Inc. All Rights Reserved
+ *
+ ******************************************/
+/**
+ * @file rgb_swap.cpp
+ * @author yinyijun@baidu.com
+ * @date 2020-06-01
+ **/
+#include "rgb_swap.h"
+
+#include
+#include "utils.h"
+
+// Channel order handed to NPP: {2, 1, 0} reverses the three channels.
+const int SwapChannel::_ORDER[CHANNEL_SIZE] = {2, 1, 0};
+
+// Reorders the channels of the frame in place according to _ORDER and
+// returns the same `input` pointer.
+// Throws std::runtime_error (via verify_npp_ret) when the NPP call fails.
+std::shared_ptr SwapChannel::operator()(std::shared_ptr input) {
+    NppStatus ret = nppiSwapChannels_32f_C3IR(input->p_frame(), input->step(), input->nppi_size(), _ORDER);
+    verify_npp_ret("nppiSwapChannels_32f_C3IR", ret);
+    return input;
+}
diff --git a/core/preprocess/hwvideoframe/src/sub.cpp b/core/preprocess/hwvideoframe/src/sub.cpp
new file mode 100644
index 00000000..aaa477e4
--- /dev/null
+++ b/core/preprocess/hwvideoframe/src/sub.cpp
@@ -0,0 +1,37 @@
+/*******************************************
+ *
+ * Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
+ *
+ ******************************************/
+/**
+ * @file sub.cpp
+ * @author yinyijun@baidu.com
+ * @date 2020-06-11
+ **/
+#include "sub.h"
+
+#include
+
+#include
+#include "utils.h"
+
+// Uses the same subtractor for every channel.
+Sub::Sub(float subtractor) {
+    for (size_t i = 0; i < CHANNEL_SIZE; i++) {
+        _subtractors[i] = subtractor;
+    }
+}
+
+// Uses one subtractor per channel; `subtractors` must hold exactly
+// CHANNEL_SIZE values, otherwise std::runtime_error is thrown.
+Sub::Sub(const std::vector &subtractors) {
+    if (subtractors.size() != CHANNEL_SIZE) {
+        throw std::runtime_error("size of subtractors must be three");
+    }
+    for (size_t i = 0; i < CHANNEL_SIZE; i++) {
+        _subtractors[i] = subtractors[i];
+    }
+}
+
+// Subtracts the per-channel constants from the frame in place and returns
+// the same `input` pointer.
+// Throws std::runtime_error (via verify_npp_ret) when the NPP call fails.
+std::shared_ptr Sub::operator()(std::shared_ptr input) {
+    NppStatus ret = nppiSubC_32f_C3IR(_subtractors, input->p_frame(), input->step(), input->nppi_size());
+    verify_npp_ret("nppiSubC_32f_C3IR", ret);
+    return input;
+}
diff --git a/core/preprocess/hwvideoframe/src/utlis.cpp b/core/preprocess/hwvideoframe/src/utlis.cpp
new file mode 100644
index 00000000..52dab218
--- /dev/null
+++ b/core/preprocess/hwvideoframe/src/utlis.cpp
@@ -0,0 +1,23 @@
+
+#include
+#include
+
+#include
+
+#include "utils.h"
+
+// Throws std::runtime_error carrying "<function_name>, ret: <status>" when
+// `ret` is anything other than NPP_SUCCESS; does nothing otherwise.
+void verify_npp_ret(const std::string& function_name, NppStatus ret) {
+    if (ret != NPP_SUCCESS) {
+        std::ostringstream ss;
+        ss << function_name << ", ret: " << ret;
+        throw std::runtime_error(ss.str());
+    }
+}
+
+// Throws std::runtime_error carrying "<function_name>, ret: <status>" when
+// `ret` is anything other than cudaSuccess; does nothing otherwise.
+void verify_cuda_ret(const std::string& function_name, cudaError_t ret) {
+    if (ret != cudaSuccess) {
+        std::ostringstream ss;
+        ss << function_name << ", ret: " << ret;
+        throw std::runtime_error(ss.str());
+    }
+}
\ No newline at end of file
diff --git a/python/paddle_serving_app/reader/test_audio_reader.py b/python/paddle_serving_app/reader/test_audio_reader.py
new file mode 100644
index 00000000..07b50794
--- /dev/null
+++ b/python/paddle_serving_app/reader/test_audio_reader.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from audio_reader import AudioFeatureOp
+
+# Manual smoke run: extracts the audio track of a local sample video.
+if __name__ == '__main__':
+    local_video_path = 'case.mp4'
+    test_op = AudioFeatureOp()
+    test_op.extract_audio_from_video(local_video_path)
\ No newline at end of file
diff --git a/python/paddle_serving_app/reader/test_frame_reader.py b/python/paddle_serving_app/reader/test_frame_reader.py
new file mode 100644
index 00000000..0364d3e0
--- /dev/null
+++ b/python/paddle_serving_app/reader/test_frame_reader.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+hwextract lib use case
+
+Usage: python test_frame_reader.py <video_file>
+"""
+# `import a.b.c` only binds the name `a`; import the submodule by name so
+# that `hwextract` is actually defined below.
+from paddle_serving_app.reader import hwextract
+import sys
+if __name__ == '__main__':
+    if len(sys.argv) < 2:
+        print("Usage: %s <video_file>" % sys.argv[0])
+        sys.exit(1)
+    handler = hwextract.HwExtractFrameJpeg(0)
+    # 0, gpu card index
+    # if you want BGRA Raw Data, plz use HwExtractBGRARaw
+    handler.init_handler()
+    # init once can decode many videos
+    video_file_name = sys.argv[1]
+    # for now just support h264 codec
+    frame_list = []
+    try:
+        frame_list = handler.extract_frame(video_file_name, 1)
+        # specify file name and fps you want to extract, 0 for all frame
+    except Exception as e_frame:
+        print("Failed to cutframe, exception[%s]" % (e_frame))
+        sys.exit(1)
+    for item in frame_list:
+        # call form of print works on both Python 2 and Python 3
+        print("i am a item in frame_list")
+        # do something, for instance
+        # jpeg_array = np.array(item, copy=False)
+        # img = cv2.imdecode(item, cv2.IMREAD_COLOR)
+        # etc.....
+        item.free_memory()
+        # have to release memory
\ No newline at end of file
diff --git a/python/paddle_serving_app/reader/test_functional.py b/python/paddle_serving_app/reader/test_functional.py
new file mode 100644
index 00000000..25b6eb47
--- /dev/null
+++ b/python/paddle_serving_app/reader/test_functional.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import sys
+import numpy as np
+from paddle_serving_app.reader import Sequential, Resize, File2Image
+import libgpupreprocess as pp
+
+
+class TestOperators(unittest.TestCase):
+    """
+    test all operators, e.g. Div, Normalize
+    """
+
+    def test_div(self):
+        """
+        test Div
+        """
+        height = 4
+        width = 5
+        channels = 3
+        value = 255.0
+        img = np.arange(height * width * channels).reshape(
+            [height, width, channels])
+        seq = Sequential(
+            [pp.Image2Gpubuffer(), pp.Div(value), pp.Gpubuffer2Image()])
+        result = seq(img).reshape(-1)
+        for i in range(0, result.size):
+            self.assertAlmostEqual(i / value, result[i], 5)
+
+    def test_sub(self):
+        """
+        test Sub
+        """
+        height = 4
+        width = 5
+        channels = 3
+        img = np.arange(height * width * channels).reshape(
+            [height, width, channels])
+
+        # input size is an int
+        value = 10.0
+        seq = Sequential(
+            [pp.Image2Gpubuffer(), pp.Sub(value), pp.Gpubuffer2Image()])
+        result = seq(img).reshape(-1)
+        for i in range(0, result.size):
+            self.assertEqual(i - value, result[i])
+
+        # input size is a sequence
+        values = (9, 4, 2)
+        seq = Sequential(
+            [pp.Image2Gpubuffer(), pp.Sub(values), pp.Gpubuffer2Image()])
+        result = seq(img)
+        for i in range(0, result.shape[0]):
+            for j in range(0, result.shape[1]):
+                for k in range(0, result.shape[2]):
+                    self.assertEqual(result[i][j][k], img[i][j][k] - values[k])
+
+    def test_normalize(self):
+        """
+        test Normalize
+        """
+        height = 4
+        width = 5
+        channels = 3
+        img = np.random.rand(height, width, channels)
+        mean = [5.0, 5.0, 5.0]
+        std = [2.0, 2.0, 2.0]
+        seq = Sequential([
+            pp.Image2Gpubuffer(), pp.Normalize(mean, std), pp.Gpubuffer2Image()
+        ])
+        result = seq(img)
+        for i in range(0, height):
+            for j in range(0, width):
+                for k in range(0, channels):
+                    self.assertAlmostEqual((img[i][j][k] - mean[k]) / std[k],
+                                           result[i][j][k], 5)
+
+    def test_center_crop(self):
+        """
+        test CenterCrop
+        """
+        height = 9
+        width = 7
+        channels = 3
+        img = np.arange(height * width * channels).reshape(
+            [height, width, channels])
+        new_size = 5
+        seq = Sequential([
+            pp.Image2Gpubuffer(), pp.CenterCrop(new_size), pp.Gpubuffer2Image()
+        ])
+        result = seq(img)
+        self.assertEqual(result.shape[0], new_size)
+        self.assertEqual(result.shape[1], new_size)
+        self.assertEqual(result.shape[2], channels)
+
+    def test_resize(self):
+        """
+        test Resize
+        """
+        height = 9
+        width = 5
+        channels = 3
+        img = np.arange(height * width).reshape([height, width, 1]) * np.ones(
+            (1, channels))
+
+        # input size is an int
+        for new_size in [3, 10]:
+            seq_gpu = Sequential([
+                pp.Image2Gpubuffer(), pp.Resize(new_size), pp.Gpubuffer2Image()
+            ])
+            seq_paddle = Sequential([Resize(new_size)])
+            result_gpu = seq_gpu(img)
+            result_paddle = seq_paddle(img)
+            self.assertEqual(result_gpu.shape, result_paddle.shape)
+            for i in range(0, result_gpu.shape[0]):
+                for j in range(0, result_gpu.shape[1]):
+                    for k in range(0, result_gpu.shape[2]):
+                        self.assertAlmostEqual(result_gpu[i][j][k],
+                                               result_paddle[i][j][k], 5)
+
+        # input size is a sequence
+        for new_height, new_width in [(7, 3), (15, 10)]:
+            seq_gpu = Sequential([
+                pp.Image2Gpubuffer(), pp.Resize((new_width, new_height)),
+                pp.Gpubuffer2Image()
+            ])
+            seq_paddle = Sequential([Resize((new_width, new_height))])
+            result_gpu = seq_gpu(img)
+            result_paddle = seq_paddle(img)
+            self.assertEqual(result_gpu.shape, result_paddle.shape)
+            for i in range(0, result_gpu.shape[0]):
+                for j in range(0, result_gpu.shape[1]):
+                    for k in range(0, result_gpu.shape[2]):
+                        self.assertAlmostEqual(result_gpu[i][j][k],
+                                               result_paddle[i][j][k], 5)
+
+    def test_resize_fixed_point(self):
+        """
+        test Resize by using fixed-point
+        """
+        new_height = 256
+        # Floor division keeps the width an int on Python 3 as well; it is
+        # used as an image size and as a np.resize dimension below. On
+        # Python 2 the value is unchanged (341).
+        new_width = 256 * 4 // 3
+        seq = Sequential([
+            File2Image(), pp.Image2Gpubuffer(), pp.Resize(
+                (new_width, new_height), use_fixed_point=True),
+            pp.Gpubuffer2Image()
+        ])
+        img = seq("./capture_16.bmp")
+        img = np.resize(img, (new_height, new_width * 3))
+        img_vis = np.loadtxt("./cap_resize_16.raw")
+        img_resize_diff = img_vis - img
+        self.assertEqual(np.all(img_resize_diff == 0), True)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/python/setup.py.app.in b/python/setup.py.app.in index 1a06b0d3..43cc4342 100644 --- a/python/setup.py.app.in +++ b/python/setup.py.app.in @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import platform import os from setuptools import setup, Distribution, Extension @@ -23,17 +24,33 @@ from setuptools import find_packages from setuptools import setup from paddle_serving_app.version import serving_app_version from pkg_resources import DistributionNotFound, get_distribution -import util -max_version, mid_version, min_version = util.python_version() +def python_version(): + return [int(v) for v in platform.python_version().split(".")] + +def find_package(pkgname): + try: + get_distribution(pkgname) + return True + except DistributionNotFound: + return False + +max_version, mid_version, min_version = python_version() if '${PACK}' == 'ON': copy_lib() +os.system('cp ../core/preprocess/nvdec-extractframe/libhwextract.so ./') +os.system('mv ./libhwextract.so ./paddle_serving_app/reader/hwextract.so') +os.system('cp ../core/preprocess/hwvideoframe/libgpupreprocess.so ./paddle_serving_app/reader') +os.system('mkdir ./paddle_serving_app/reader/lib') +os.system('cp ../core/preprocess/nvdec-extractframe/cuda/libhwgpu.so ./paddle_serving_app/reader/lib') +os.system('cp
../core/preprocess/hwvideoframe/cuda/libgpu.so ./paddle_serving_app/reader/lib')
+# NOTE(review): os.system runs its command in a subshell, so this `export`
+# does not change the environment of this setup process or of any later
+# command. If the library path is really needed, it has to be set elsewhere
+# (e.g. os.environ here, or by the user at run time) -- confirm the intent.
+os.system('export LD_LIBRARY_PATH="./paddle_serving_app/reader/lib"')
 REQUIRED_PACKAGES = [ 'six >= 1.10.0', 'sentencepiece', 'opencv-python<=4.2.0.32', 'pillow', - 'shapely<=1.6.1', 'pyclipper' + 'shapely', 'pyclipper' ] packages=['paddle_serving_app', @@ -41,9 +58,11 @@ packages=['paddle_serving_app', 'paddle_serving_app.reader', 'paddle_serving_app.utils', 'paddle_serving_app.models', - 'paddle_serving_app.reader.pddet'] + 'paddle_serving_app.reader.pddet', + 'paddle_serving_app.reader.lib'] -package_data={} +package_data={'paddle_serving_app': ['reader/*.so']} + # 'paddle_serving_app.reader': ['lib/*.so']} package_dir={'paddle_serving_app': '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app', 'paddle_serving_app.proto': @@ -55,7 +74,9 @@ package_dir={'paddle_serving_app': 'paddle_serving_app.models': '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/models', 'paddle_serving_app.reader.pddet': - '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',} + '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet', + 'paddle_serving_app.reader.lib': + '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/lib',} setup( name='paddle-serving-app', @@ -89,4 +110,4 @@ setup( 'Topic :: Software Development :: Libraries :: Python Modules', ], license='Apache 2.0', - keywords=('paddle-serving serving-client deployment industrial easy-to-use')) + keywords=('paddle-serving serving-client deployment industrial easy-to-use')) \ No newline at end of file -- GitLab