未验证 提交 d85a7733 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #30 from PaddlePaddle/develop

Sync latest code from branch develop
...@@ -124,7 +124,7 @@ Recommended to install paddle >= 2.0.0 ...@@ -124,7 +124,7 @@ Recommended to install paddle >= 2.0.0
pip install paddlepaddle==2.0.0 pip install paddlepaddle==2.0.0
# GPU Cuda10.2 please run # GPU Cuda10.2 please run
pip install paddlepaddle-gpu==2.0.0 pip install paddlepaddle-gpu==2.0.0
``` ```
**Note**: If your Cuda version is not 10.2, please do not execute the above commands directly, you need to refer to [Paddle official documentation-multi-version whl package list **Note**: If your Cuda version is not 10.2, please do not execute the above commands directly, you need to refer to [Paddle official documentation-multi-version whl package list
...@@ -135,8 +135,12 @@ The url corresponding to `cuda9.0_cudnn7-mkl`, copy it and run ...@@ -135,8 +135,12 @@ The url corresponding to `cuda9.0_cudnn7-mkl`, copy it and run
``` ```
pip install https://paddle-wheel.bj.bcebos.com/2.0.0-gpu-cuda9-cudnn7-mkl/paddlepaddle_gpu-2.0.0.post90-cp27-cp27mu-linux_x86_64.whl pip install https://paddle-wheel.bj.bcebos.com/2.0.0-gpu-cuda9-cudnn7-mkl/paddlepaddle_gpu-2.0.0.post90-cp27-cp27mu-linux_x86_64.whl
``` ```
The default `paddlepaddle-gpu==2.0.0` is built for Cuda 10.2 without TensorRT. If you want to install PaddlePaddle with TensorRT, please also check the documentation-multi-version whl package list and look for the keyword `cuda10.2-cudnn8.0-trt7.1.3`. For more information, please see [Paddle Serving uses TensorRT](./doc/TENSOR_RT.md)
If it is other environment and Python version, please find the corresponding link in the table and install it with pip. If it is other environment and Python version, please find the corresponding link in the table and install it with pip.
For **Windows Users**, please read the document [Paddle Serving for Windows Users](./doc/WINDOWS_TUTORIAL.md) For **Windows Users**, please read the document [Paddle Serving for Windows Users](./doc/WINDOWS_TUTORIAL.md)
<h2 align="center">Quick Start Example</h2> <h2 align="center">Quick Start Example</h2>
...@@ -219,6 +223,7 @@ the response is ...@@ -219,6 +223,7 @@ the response is
- [Develop Pipeline Serving](doc/PIPELINE_SERVING.md) - [Develop Pipeline Serving](doc/PIPELINE_SERVING.md)
- [Deploy Web Service with uWSGI](doc/UWSGI_DEPLOY.md) - [Deploy Web Service with uWSGI](doc/UWSGI_DEPLOY.md)
- [Hot loading for model file](doc/HOT_LOADING_IN_SERVING.md) - [Hot loading for model file](doc/HOT_LOADING_IN_SERVING.md)
- [Paddle Serving uses TensorRT](doc/TENSOR_RT.md)
### About Efficiency ### About Efficiency
- [How to profile Paddle Serving latency?](python/examples/util) - [How to profile Paddle Serving latency?](python/examples/util)
......
...@@ -112,7 +112,7 @@ pip install paddle-serving-server-gpu==0.5.0.post11 # GPU with CUDA10.1 + Tensor ...@@ -112,7 +112,7 @@ pip install paddle-serving-server-gpu==0.5.0.post11 # GPU with CUDA10.1 + Tensor
您可能需要使用国内镜像源(例如清华源, 在pip命令中添加`-i https://pypi.tuna.tsinghua.edu.cn/simple`)来加速下载。 您可能需要使用国内镜像源(例如清华源, 在pip命令中添加`-i https://pypi.tuna.tsinghua.edu.cn/simple`)来加速下载。
如果需要使用develop分支编译的安装包,请从[最新安装包列表](./doc/LATEST_PACKAGES.md)中获取下载地址进行下载,使用`pip install`命令进行安装。如果您想自行编译,请参照[Paddle Serving编译文档](./doc/COMPILE_CN.md) 如果需要使用develop分支编译的安装包,请从[最新安装包列表](./doc/LATEST_PACKAGES.md)中获取下载地址进行下载,使用`pip install`命令进行安装。如果您想自行编译,请参照[Paddle Serving编译文档](./doc/COMPILE_CN.md)
paddle-serving-server和paddle-serving-server-gpu安装包支持Centos 6/7, Ubuntu 16/18和Windows 10。 paddle-serving-server和paddle-serving-server-gpu安装包支持Centos 6/7, Ubuntu 16/18和Windows 10。
...@@ -134,6 +134,8 @@ pip install paddlepaddle-gpu==2.0.0 ...@@ -134,6 +134,8 @@ pip install paddlepaddle-gpu==2.0.0
``` ```
pip install https://paddle-wheel.bj.bcebos.com/2.0.0-gpu-cuda9-cudnn7-mkl/paddlepaddle_gpu-2.0.0.post90-cp27-cp27mu-linux_x86_64.whl pip install https://paddle-wheel.bj.bcebos.com/2.0.0-gpu-cuda9-cudnn7-mkl/paddlepaddle_gpu-2.0.0.post90-cp27-cp27mu-linux_x86_64.whl
``` ```
由于默认的`paddlepaddle-gpu==2.0.0`是Cuda 10.2,并没有联编TensorRT,因此如果需要在`paddlepaddle-gpu`上使用TensorRT,需要在上述多版本whl包列表当中,找到`cuda10.2-cudnn8.0-trt7.1.3`,下载对应的Python版本。更多信息请参考[如何使用TensorRT?](doc/TENSOR_RT_CN.md)
如果是其他环境和Python版本,请在表格中找到对应的链接并用pip安装。 如果是其他环境和Python版本,请在表格中找到对应的链接并用pip安装。
对于**Windows 10 用户**,请参考文档[Windows平台使用Paddle Serving指导](./doc/WINDOWS_TUTORIAL_CN.md) 对于**Windows 10 用户**,请参考文档[Windows平台使用Paddle Serving指导](./doc/WINDOWS_TUTORIAL_CN.md)
...@@ -220,6 +222,7 @@ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1 ...@@ -220,6 +222,7 @@ curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1
- [如何开发Pipeline?](doc/PIPELINE_SERVING_CN.md) - [如何开发Pipeline?](doc/PIPELINE_SERVING_CN.md)
- [如何使用uWSGI部署Web Service](doc/UWSGI_DEPLOY_CN.md) - [如何使用uWSGI部署Web Service](doc/UWSGI_DEPLOY_CN.md)
- [如何实现模型文件热加载](doc/HOT_LOADING_IN_SERVING_CN.md) - [如何实现模型文件热加载](doc/HOT_LOADING_IN_SERVING_CN.md)
- [如何使用TensorRT?](doc/TENSOR_RT_CN.md)
### 关于Paddle Serving性能 ### 关于Paddle Serving性能
- [如何测试Paddle Serving性能?](python/examples/util/) - [如何测试Paddle Serving性能?](python/examples/util/)
......
...@@ -25,6 +25,7 @@ endif() ...@@ -25,6 +25,7 @@ endif()
if (APP) if (APP)
add_subdirectory(configure) add_subdirectory(configure)
add_subdirectory(preprocess)
endif() endif()
......
# GPU image preprocessing is only built when the project is configured with
# WITH_GPU; otherwise this directory contributes nothing to the build.
if(NOT WITH_GPU)
  message("GPU preprocess will not be compiled.")
  return()
endif()
message(STATUS "CUDA detected: " ${CUDA_VERSION})
# The expansion is quoted so that an empty/undefined CUDA_VERSION does not
# turn the condition into a syntax error, and VERSION_LESS is used because
# CUDA_VERSION is a dotted version string, not a real number.
if("${CUDA_VERSION}" VERSION_LESS "10.0")
  message("CUDA version should be 10.0.")
  return()
elseif("${CUDA_VERSION}" VERSION_LESS "10.1") # CUDA 10.0
  add_subdirectory(hwvideoframe)
endif()
# Builds libgpupreprocess, the shared library that holds the hwvideoframe
# operators (src/*.cpp) and their pybind11 bindings (pybind/*.cpp).
cmake_minimum_required(VERSION 3.2)
project(gpupreprocess)
include(cuda)
include(configure)
# C flags (these replace, rather than extend, any inherited flags).
set(CMAKE_C_FLAGS " -g -pipe -W -Wall -fPIC")
# C++ flags.
set(CMAKE_CXX_FLAGS " -g -pipe -W -Wall -fPIC -std=c++11")
# The "gpu" target (CUDA kernels) is defined in the cuda/ subdirectory.
add_subdirectory(cuda)
# Path of the Python library; the original was missing the "$" and therefore
# stored the literal text "{PYTHON_LIBRARY}".
set(PYTHON_SO ${PYTHON_LIBRARY})
set(EXTRA_LIBS ${EXTRA_LIBS} gpu)
file(GLOB SOURCE_FILES pybind/*.cpp src/*.cpp)
include_directories("./include")
include_directories(${CUDA_INCLUDE_DIRS})
include_directories(${PYTHON_INCLUDE_DIR})
# NOTE(review): the "/usr"-prefixed path is kept only for backward
# compatibility with environments that relied on it - confirm and remove.
include_directories(/usr${PYTHON_INCLUDE_DIR})
# link_directories() takes directories only; the linker flags that used to be
# passed here are now expressed through target_link_libraries() below.
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
if(PYTHON_SO)
  get_filename_component(PYTHON_SO_DIR "${PYTHON_SO}" DIRECTORY)
  link_directories("${PYTHON_SO_DIR}")
endif()
# .so
add_library(gpupreprocess SHARED ${SOURCE_FILES})
target_link_libraries(gpupreprocess ${EXTRA_LIBS})
# CUDA_LIBRARIES (from FindCUDA) provides the CUDA runtime.
target_link_libraries(gpupreprocess ${CUDA_LIBRARIES})
# Static NPP libraries that the original link line named via -l flags.
target_link_libraries(gpupreprocess
  nppidei_static
  nppial_static
  npps_static
  nppc_static
  culibos)
# hwvideoframe
Hwvideoframe is a CUDA-based CV preprocessing library. It runs image preprocessing operations on the GPU, which speeds up processing and increases GPU utilization.
## Preprocess API
Hwvideoframe provides a variety of data preprocessing methods for photo preprocess:
- class Image2Gpubuffer
- `__call__(img)`
- img(np.array):Image data.
- class Gpubuffer2Image
- `__call__(img)`
- img(np.array):Image data.
- class Div
- `__init__(value)`
- value(float):Constant value to be divided.
- `__call__(img)`
- img(np.array):Image data.
- class Sub
- `__init__(subtractor)`
- subtractor(list/float):Three 32-bit floating point channel image subtract constant. When the input is a list type, length of list must be three.
- `__call__(img)`
- img(np.array):Image data in (C,H,W) channels.
- class Normalize
- `__init__(mean,std)`
- mean(list):Mean. Length of list must be three.
- std(list):Standard deviation. Length of list must be three.
- `__call__(img)`
- img(np.array):Image data in (C,H,W) channels.
- class CenterCrop
- `__init__(size)`
- size(int):Target size of the center crop. It must not be larger than the height or the width of the input image.
- `__call__(img)`
- img(np.array):Image data in (C,H,W) channels.
- class Resize
- `__init__(size, max_size=2147483647, interpolation=None)`
- size(list/int):The expected image size, when the input is a list type, it needs to contain the expected length and width. When the input is int type, the short side will be set to the length of size, and the long side will be scaled proportionally.
- `__call__(img)`
- img(numpy array):Image data in (C,H,W) channels.
## Quick start
After [compiling from source](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE.md), the build output of this project is placed in `reader`.
## How to Test
Test file:Serving/python/paddle_serving_app/reader/test_preprocess.py
# hwvideoframe
hwvideoframe是一个基于cuda的图片预处理库。项目利用GPU进行图片预处理操作,在加快处理速度的同时,提高GPU的使用率。
## 预处理API
hwvideoframe目前提供以下图片预处理功能:
- class Image2Gpubuffer
- `__call__(img)`
- img(np.array):输入图像。
- class Gpubuffer2Image
- `__call__(img)`
- img(np.array):输入图像。
- class Div
- `__init__(value)`
- value(float):除数,图像中的每个像素值都会除以该常数。
- `__call__(img)`
- img(np.array):输入图像。
- class Sub
- `__init__(subtractor)`
- subtractor(list/float):list的长度必须为3。
- `__call__(img)`
- img(np.array):(C,H,W)排列的图像数据。
- class Normalize
- `__init__(mean,std)`
- mean(list):均值。 list长度必须为3。
- std(list):标准差。 list长度必须为3。
- `__call__(img)`
- img(np.array):(C,H,W)排列的图像数据。
- class CenterCrop
- `__init__(size)`
- size(int):预期的裁剪后的大小,list类型时需要包含预期的长和宽,int类型时会返回边长为size的正方形图片。size不能大于原始图片大小。
- `__call__(img)`
- img(np.array):(C,H,W)排列的图像数据
- class Resize
- `__init__(size, max_size=2147483647, interpolation=None)`
- size(list/int):预期的图像大小,短边会设置为size的长度,长边按比例缩放.
- `__call__(img)`
- img(numpy array):(C,H,W)排列的图像数据
## 快速开始
按照Paddle Serving文档编译,编译结果在reader中。
## 测试
测试文件:Serving/python/paddle_serving_app/reader/test_preprocess.py
# Builds the "gpu" shared library from the CUDA/C++ sources in this directory.
cmake_minimum_required(VERSION 3.2)
project(gpu)
find_package(CUDA ${CUDA_VERSION} REQUIRED)
set(CUDA_TARGET_INCLUDE ${CUDA_TOOLKIT_ROOT_DIR}-${CUDA_VERSION}/targets/${CMAKE_HOST_SYSTEM_PROCESSOR}-${LOWER_SYSTEM_NAME}/include)
file(GLOB_RECURSE CURRENT_HEADERS *.h *.hpp *.cuh)
file(GLOB CURRENT_SOURCES *.cpp *.cu)
source_group("Include" FILES ${CURRENT_HEADERS})
source_group("Source" FILES ${CURRENT_SOURCES})
# NOTE(review): both flag strings hard-code /opt/compiler/gcc-4.8.2/bin, and
# the second passes it with -L although it looks like a host-compiler
# directory (-ccbin) - confirm against the build environment.
set(CMAKE_CUDA_FLAGS "-ccbin /opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
set(CUDA_NVCC_FLAGS "-L/opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
include_directories(${CUDA_INCLUDE_DIRS})
cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES})
# FindCUDA exports CUDA_LIBRARIES; CUDA_LIBS is not set by that module, so it
# is kept only in case a parent scope defines it. The keyword-less signature
# is used because cuda_add_library() links with the plain signature as well.
target_link_libraries(gpu ${CUDA_LIBRARIES} ${CUDA_LIBS})
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "./cuda_runtime.h"
// Clamps x into [a, b - 1]: yields a when x < a, b - 1 when x >= b, and x
// otherwise. Every argument and the whole expansion are parenthesised, and
// there is no trailing semicolon, so the macro is safe inside larger
// expressions. Arguments may be evaluated more than once.
#define clip(x, a, b) ((x) >= (a) ? ((x) < (b) ? (x) : ((b)-1)) : (a))
// Number of fractional bits used for the fixed-point interpolation weights.
const int INTER_RESIZE_COEF_BITS = 11;
// Fixed-point representation of the weight 1.0 (1 << 11 == 2048).
const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS;
// Bilinear resize kernel working on float samples.
//
// Each thread produces one output pixel (dx, dy). Both buffers are indexed as
// (row * width + column) * channels + channel. Only 3-channel input is
// handled; for any other channel count the kernel writes nothing.
//
// input/output : source and destination buffers
// inputWidth/inputHeight   : source size in pixels
// outputWidth/outputHeight : destination size in pixels
// inputChannels            : channels per pixel (must be 3 to produce output)
__global__ void resizeCudaKernel(const float* input,
                                 float* output,
                                 const int inputWidth,
                                 const int inputHeight,
                                 const int outputWidth,
                                 const int outputHeight,
                                 const int inputChannels) {
  // 2D index of the output pixel handled by this thread
  const int dx = blockIdx.x * blockDim.x + threadIdx.x;
  const int dy = blockIdx.y * blockDim.y + threadIdx.y;
  if ((dx >= outputWidth) || (dy >= outputHeight)) {
    return;
  }
  if (inputChannels != 3) {
    // TODO(Zelda): support grayscale and alpha channel
    return;
  }

  // Convert before dividing: dividing the ints first truncated the scale
  // (e.g. 500 / 224 became 2 instead of ~2.23).
  const double scale_x = static_cast<double>(inputWidth) / outputWidth;
  const double scale_y = static_cast<double>(inputHeight) / outputHeight;

  // Horizontal source position (pixel centres are aligned).
  float fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
  const int sx = floorf(fx);
  fx = fx - sx;
  int isx1 = sx;
  if (isx1 < 0) {
    fx = 0.0;
    isx1 = 0;
  }
  // On the right border there is no right-hand neighbour: the left sample is
  // used unblended.
  bool on_right_edge = false;
  if (isx1 >= (inputWidth - 1)) {
    on_right_edge = true;
    fx = 0;
    isx1 = inputWidth - 1;
  }
  // Clamp the neighbour column so that it never indexes past the row end.
  const int isx2 = ::min(isx1 + 1, inputWidth - 1);
  const float wx0 = 1.f - fx;
  const float wx1 = fx;

  // Vertical source position; both rows are clamped to [0, inputHeight - 1].
  float fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
  const int sy = floorf(fy);
  fy = fy - sy;
  const int isy1 = ::min(::max(sy, 0), inputHeight - 1);
  const int isy2 = ::min(::max(sy + 1, 0), inputHeight - 1);
  const float wy0 = 1.f - fy;
  const float wy1 = fy;

  // The four neighbouring source pixels.
  const float* p11 = input + (isy1 * inputWidth + isx1) * inputChannels;
  const float* p12 = input + (isy1 * inputWidth + isx2) * inputChannels;
  const float* p21 = input + (isy2 * inputWidth + isx1) * inputChannels;
  const float* p22 = input + (isy2 * inputWidth + isx2) * inputChannels;
  float* dst = output + (dy * outputWidth + dx) * 3;

  // The same interpolation is applied to every channel, so the channel order
  // of the input is preserved in the output.
  for (int c = 0; c < 3; ++c) {
    float top;
    float bottom;
    if (on_right_edge) {
      top = p11[c];
      bottom = p21[c];
    } else {
      top = p11[c] * wx0 + p12[c] * wx1;
      bottom = p21[c] * wx0 + p22[c] * wx1;
    }
    dst[c] = top * wy0 + bottom * wy1;
  }
}
// Bilinear resize kernel using fixed-point interpolation weights.
//
// Same indexing and parameters as resizeCudaKernel, but every input sample is
// converted to int, the weights are scaled by INTER_RESIZE_COEF_SCALE, and
// the blended value is narrowed to unsigned char before being stored as a
// float. Only 3-channel input is handled; otherwise nothing is written.
__global__ void resizeCudaKernel_fixpt(const float* input,
                                       float* output,
                                       const int inputWidth,
                                       const int inputHeight,
                                       const int outputWidth,
                                       const int outputHeight,
                                       const int inputChannels) {
  // 2D index of the output pixel handled by this thread
  const int dx = blockIdx.x * blockDim.x + threadIdx.x;
  const int dy = blockIdx.y * blockDim.y + threadIdx.y;
  if ((dx >= outputWidth) || (dy >= outputHeight)) {
    return;
  }
  if (inputChannels != 3) {
    // TODO(Zelda): support grayscale and alpha channel
    return;
  }

  // Convert before dividing: dividing the ints first truncated the scale.
  const double scale_x = static_cast<double>(inputWidth) / outputWidth;
  const double scale_y = static_cast<double>(inputHeight) / outputHeight;

  // Horizontal source position (pixel centres are aligned).
  float fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
  const int sx = floorf(fx);
  fx = fx - sx;
  int isx1 = sx;
  if (isx1 < 0) {
    fx = 0.0;
    isx1 = 0;
  }
  // On the right border the left sample is used unblended.
  bool on_right_edge = false;
  if (isx1 >= (inputWidth - 1)) {
    on_right_edge = true;
    fx = 0;
    isx1 = inputWidth - 1;
  }
  // Clamp the neighbour column so that it never indexes past the row end.
  const int isx2 = ::min(isx1 + 1, inputWidth - 1);
  const int wx0 = static_cast<int>(lrintf((1.f - fx) * INTER_RESIZE_COEF_SCALE));
  const int wx1 = static_cast<int>(lrintf(fx * INTER_RESIZE_COEF_SCALE));

  // Vertical source position; both rows are clamped to [0, inputHeight - 1].
  float fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
  const int sy = floorf(fy);
  fy = fy - sy;
  const int isy1 = ::min(::max(sy, 0), inputHeight - 1);
  const int isy2 = ::min(::max(sy + 1, 0), inputHeight - 1);
  const int wy0 = static_cast<int>(lrintf((1.f - fy) * INTER_RESIZE_COEF_SCALE));
  const int wy1 = static_cast<int>(lrintf(fy * INTER_RESIZE_COEF_SCALE));

  // The four neighbouring source pixels.
  const float* p11 = input + (isy1 * inputWidth + isx1) * inputChannels;
  const float* p12 = input + (isy1 * inputWidth + isx2) * inputChannels;
  const float* p21 = input + (isy2 * inputWidth + isx1) * inputChannels;
  const float* p22 = input + (isy2 * inputWidth + isx2) * inputChannels;
  float* dst = output + (dy * outputWidth + dx) * 3;

  for (int c = 0; c < 3; ++c) {
    // Samples are truncated to int before the fixed-point arithmetic.
    const int v11 = static_cast<int>(p11[c]);
    const int v12 = static_cast<int>(p12[c]);
    const int v21 = static_cast<int>(p21[c]);
    const int v22 = static_cast<int>(p22[c]);
    int top;
    int bottom;
    if (on_right_edge) {
      top = v11 * INTER_RESIZE_COEF_SCALE;
      bottom = v21 * INTER_RESIZE_COEF_SCALE;
    } else {
      top = v11 * wx0 + v12 * wx1;
      bottom = v21 * wx0 + v22 * wx1;
    }
    // Horizontal results carry 11 fractional bits and the vertical weights
    // another 11; the shifts (4 + 16 + 2 = 22) remove them, with "+ 2"
    // rounding the last two bits.
    const int blended =
        (((wy0 * (top >> 4)) >> 16) + ((wy1 * (bottom >> 4)) >> 16) + 2) >> 2;
    dst[c] = static_cast<unsigned char>(blended);
  }
}
// Host entry point for the bilinear resize.
//
// Launches resizeCudaKernel_fixpt (use_fixed_point == true) or
// resizeCudaKernel with 16x16 thread blocks covering the output image, then
// waits for the device to finish.
//
// Returns cudaErrorInvalidValue for null buffers or non-positive dimensions,
// the launch error if the kernel could not be started, and otherwise the
// result of cudaDeviceSynchronize().
extern "C" cudaError_t resize_linear(const float* input,
                                     float* output,
                                     const int inputWidth,
                                     const int inputHeight,
                                     const int outputWidth,
                                     const int outputHeight,
                                     const int inputChannels,
                                     const bool use_fixed_point) {
  // Reject arguments that would produce an empty grid or a division by zero
  // inside the kernels.
  if (input == nullptr || output == nullptr || inputWidth <= 0 ||
      inputHeight <= 0 || outputWidth <= 0 || outputHeight <= 0) {
    return cudaErrorInvalidValue;
  }
  // Specify a reasonable block size
  const dim3 block(16, 16);
  // Calculate grid size to cover the whole image
  const dim3 grid((outputWidth + block.x - 1) / block.x,
                  (outputHeight + block.y - 1) / block.y);
  // Launch the size conversion kernel
  if (use_fixed_point) {
    resizeCudaKernel_fixpt<<<grid, block>>>(input,
                                            output,
                                            inputWidth,
                                            inputHeight,
                                            outputWidth,
                                            outputHeight,
                                            inputChannels);
  } else {
    resizeCudaKernel<<<grid, block>>>(input,
                                      output,
                                      inputWidth,
                                      inputHeight,
                                      outputWidth,
                                      outputHeight,
                                      inputChannels);
  }
  // Launch failures (e.g. an invalid configuration) are reported through
  // cudaGetLastError(), not necessarily through the synchronize call.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    return launch_status;
  }
  // Synchronize to surface errors raised while the kernel was running
  return cudaDeviceSynchronize();
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
#include <npp.h>
#include <memory>
#include "./op_context.h"
// Operator that crops the given image at its center.
// The requested size must not be bigger than the input's height or width.
class CenterCrop {
public:
// size: the crop size, stored unchanged and used by operator().
explicit CenterCrop(int size) : _size(size) {}
// Applies the crop to `input` and returns the resulting context.
// Defined out of line; this header does not show whether `input` is reused.
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _size;  // crop size passed to the constructor
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
#include <npp.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Operator that divides every pixel by a float constant.
class Div {
public:
// value: the constant to divide by. Defined out of line.
explicit Div(float value);
// Applies the division to `input` and returns the resulting context.
// Defined out of line; this header does not show whether `input` is reused.
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _divisor;  // presumably holds `value` from the constructor - confirm in the .cpp
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
#include <npp.h>
#include <pybind11/numpy.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Input operator that copies numpy data into a GPU buffer.
class Image2Gpubuffer {
public:
// array: float numpy array holding the image data.
// Returns the OpContext that describes the GPU copy. Defined out of line.
std::shared_ptr<OpContext> operator()(pybind11::array_t<float> array);
};
// Output operator that copies GPU buffer data back into a numpy array.
class Gpubuffer2Image {
public:
// input: context describing the GPU buffer to read.
// Returns a float numpy array with the copied data. Defined out of line.
pybind11::array_t<float> operator()(std::shared_ptr<OpContext> input);
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Normalize operator that runs on the GPU.
class Normalize {
public:
// mean, std: per-channel values; the members below hold CHANNEL_SIZE entries
// each. channel_first defaults to false. Defined out of line.
Normalize(const std::vector<float> &mean,
const std::vector<float> &std,
bool channel_first = false);
// Applies the normalization to `input` and returns the resulting context.
// Defined out of line; this header does not show whether `input` is reused.
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _mean[CHANNEL_SIZE];
Npp32f _std[CHANNEL_SIZE];
bool _channel_first; // indicates whether the channel is dimension 0,
// unsupported
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
#include <npp.h>
const size_t CHANNEL_SIZE = 3;
// The context as input/ouput of all operators
// contains pointer to raw data on gpu, frame size
class OpContext {
public:
OpContext() {
_step = 0;
_size = 0;
_p_frame = nullptr;
}
// constructor to apply gpu memory of image raw data
OpContext(int height, int width) {
_step = sizeof(Npp32f) * width * CHANNEL_SIZE;
_length = height * width * CHANNEL_SIZE;
_size = _step * height;
_nppi_size.height = height;
_nppi_size.width = width;
cudaMalloc(reinterpret_cast<void**>(&_p_frame), _size);
}
virtual ~OpContext() { free_memory(); }
public:
Npp32f* p_frame() const { return _p_frame; }
int step() const { return _step; }
int length() const { return _length; }
int size() const { return _size; }
NppiSize& nppi_size() { return _nppi_size; }
void free_memory() {
if (_p_frame != nullptr) {
cudaFree(_p_frame);
_p_frame = nullptr;
}
_nppi_size.height = 0;
_nppi_size.width = 0;
_step = 0;
_size = 0;
}
private:
Npp32f* _p_frame; // pointer to raw data on gpu
int _step; // number of bytes in a row
int _length; // length of _p_frame, _size = _length * sizeof(Npp32f)
int _size; // number of bytes of the image
NppiSize _nppi_size; // contains height and width
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Bilinear resize of `input` (inputWidth x inputHeight) into `output`
// (outputWidth x outputHeight). use_fixed_point selects the fixed-point
// implementation. Returns a CUDA status code.
extern "C" cudaError_t resize_linear(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point);
// Resizes the input image to the given size.
// only support linear interpolation
// only support RGB channels
class Resize {
 public:
  // size is an int, smaller edge of the image will be matched to this number.
  // NOTE(review): the default max_size (214748364) differs from the
  // 2147483647 shown in the README signature - confirm which is intended.
  Resize(int size,
         int max_size = 214748364,
         bool use_fixed_point = false,
         int interpolation = 0)
      : _size(size),
        _max_size(max_size),
        _use_fixed_point(use_fixed_point),
        _interpolation(interpolation) {
    // Not used in this mode, but never left indeterminate.
    _target_size[0] = 0;
    _target_size[1] = 0;
  }
  // size is a sequence like (w, h), output size will be matched to this.
  // Throws std::out_of_range when size has fewer than two elements.
  Resize(std::vector<int> size,
         int max_size = 214748364,
         bool use_fixed_point = false,
         int interpolation = 0)
      : _size(-1),
        _max_size(max_size),
        _use_fixed_point(use_fixed_point),
        _interpolation(interpolation) {
    // at() is bounds-checked, unlike operator[].
    _target_size[0] = size.at(0);
    _target_size[1] = size.at(1);
  }
  // Applies the resize to `input` and returns the resulting context.
  // Defined out of line.
  std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);

 private:
  int _size;            // target of smaller edge; -1 when a (w, h) pair is used
  int _target_size[2];  // target size sequence (w, h)
  int _max_size;
  bool _use_fixed_point;
  int _interpolation;  // unused
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Bilinear resize of `input` (inputWidth x inputHeight) into `output`
// (outputWidth x outputHeight). use_fixed_point selects the fixed-point
// implementation. Returns a CUDA status code.
extern "C" cudaError_t resize_linear(const float *input,
float *output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point);
// Resizes the input image so that its width and height are multiples of
// `factor`, which is usually required by a network.
// only support linear interpolation
// only support RGB channels
class ResizeByFactor {
public:
// factor: width and height are made multiples of this value. Default is 32.
// max_side_len: stored as given (default 2400); how it limits the result is
// decided in the out-of-line operator().
ResizeByFactor(int factor = 32,
int max_side_len = 2400,
bool use_fixed_point = false,
int interpolation = 0)
: _factor(factor),
_max_side_len(max_side_len),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {}
// Applies the resize to `input` and returns the resulting context.
// Defined out of line.
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _factor; // output width and height are made multiples of this value
int _max_side_len;
bool _use_fixed_point;
int _interpolation; // unused
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
#include <npp.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Swaps channel 0 and channel 2 of every pixel.
// Both RGB2BGR and BGR2RGB use this operator.
class SwapChannel {
public:
SwapChannel() {}
// Applies the channel swap to `input` and returns the resulting context.
// Defined out of line.
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
static const int
_ORDER[CHANNEL_SIZE]; // describes how channel values are permuted
};
// Named aliases of SwapChannel; neither adds behavior of its own.
class RGB2BGR : public SwapChannel {};
class BGR2RGB : public SwapChannel {};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Operator that subtracts float constants from an image.
class Sub {
public:
// subtractor: a single constant. Defined out of line.
explicit Sub(float subtractor);
// subtractors: a list of constants; the member below holds CHANNEL_SIZE
// values. Defined out of line.
explicit Sub(const std::vector<float> &subtractors);
// Applies the subtraction to `input` and returns the resulting context.
// Defined out of line.
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _subtractors[CHANNEL_SIZE];  // one value per channel
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
#include <npp.h>
#include <string>
// Verifies the return value of an NPP function.
// function_name: name of the checked call; ret: the status it returned.
// Throws an exception if the call failed.
void verify_npp_ret(const std::string& function_name, NppStatus ret);
// Verifies the return value of a CUDA runtime function.
// function_name: name of the checked call; ret: the status it returned.
// Throws an exception if the call failed.
void verify_cuda_ret(const std::string& function_name, cudaError_t ret);
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "core/preprocess/hwvideoframe/include/center_crop.h"
#include "core/preprocess/hwvideoframe/include/div.h"
#include "core/preprocess/hwvideoframe/include/image_io.h"
#include "core/preprocess/hwvideoframe/include/normalize.h"
#include "core/preprocess/hwvideoframe/include/resize.h"
#include "core/preprocess/hwvideoframe/include/resize_by_factor.h"
#include "core/preprocess/hwvideoframe/include/rgb_swap.h"
#include "core/preprocess/hwvideoframe/include/sub.h"
// Python bindings for the GPU preprocessing operators.
// Every operator is exposed as a Python class whose __call__ forwards to the
// C++ operator(); OpContext is registered with a shared_ptr holder and has no
// Python-visible methods here.
PYBIND11_MODULE(libgpupreprocess, m) {
pybind11::class_<OpContext, std::shared_ptr<OpContext>>(m, "OpContext");
// Host <-> device transfer operators.
pybind11::class_<Image2Gpubuffer>(m, "Image2Gpubuffer")
.def(pybind11::init<>())
.def("__call__", &Image2Gpubuffer::operator());
pybind11::class_<Gpubuffer2Image>(m, "Gpubuffer2Image")
.def(pybind11::init<>())
.def("__call__", &Gpubuffer2Image::operator());
// Channel-order swaps.
pybind11::class_<RGB2BGR>(m, "RGB2BGR")
.def(pybind11::init<>())
.def("__call__", &RGB2BGR::operator());
pybind11::class_<BGR2RGB>(m, "BGR2RGB")
.def(pybind11::init<>())
.def("__call__", &BGR2RGB::operator());
// Arithmetic operators.
pybind11::class_<Div>(m, "Div")
.def(pybind11::init<float>())
.def("__call__", &Div::operator());
// Sub accepts either one float or a list of floats (one per channel).
pybind11::class_<Sub>(m, "Sub")
.def(pybind11::init<float>())
.def(pybind11::init<const std::vector<float>&>())
.def("__call__", &Sub::operator());
pybind11::class_<Normalize>(m, "Normalize")
.def(pybind11::init<const std::vector<float>&,
const std::vector<float>&,
bool>(),
pybind11::arg("mean"),
pybind11::arg("std"),
pybind11::arg("channel_first") = false)
.def("__call__", &Normalize::operator());
// Geometric operators.
pybind11::class_<CenterCrop>(m, "CenterCrop")
.def(pybind11::init<int>())
.def("__call__", &CenterCrop::operator());
// Resize accepts either a single int `size` or a list `target_size`.
// NOTE(review): the max_size default 214748364 looks like INT_MAX
// (2147483647) with the last digit dropped — confirm the intended value.
pybind11::class_<Resize>(m, "Resize")
.def(pybind11::init<int, int, bool>(),
pybind11::arg("size"),
pybind11::arg("max_size") = 214748364,
pybind11::arg("use_fixed_point") = false)
.def(pybind11::init<const std::vector<int>&, int, bool>(),
pybind11::arg("target_size"),
pybind11::arg("max_size") = 214748364,
pybind11::arg("use_fixed_point") = false)
.def("__call__", &Resize::operator());
pybind11::class_<ResizeByFactor>(m, "ResizeByFactor")
.def(pybind11::init<int, int, bool>(),
pybind11::arg("factor") = 32,
pybind11::arg("max_side_len") = 2400,
pybind11::arg("use_fixed_point") = false)
.def("__call__", &ResizeByFactor::operator());
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <algorithm>
#include "core/preprocess/hwvideoframe/include/center_crop.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Copies the centered window of at most _size x _size pixels out of `input`
// into a freshly allocated context and returns it. When the source is smaller
// than _size along an axis, the full extent of that axis is kept.
std::shared_ptr<OpContext> CenterCrop::operator()(
    std::shared_ptr<OpContext> input) {
  const int src_width = input->nppi_size().width;
  const int src_height = input->nppi_size().height;

  // Clamp the crop window to the source dimensions.
  const int crop_width = std::min(_size, src_width);
  const int crop_height = std::min(_size, src_height);
  auto cropped = std::make_shared<OpContext>(crop_height, crop_width);

  // Top-left corner of the centered window.
  const int left = (src_width - crop_width) / 2;
  const int top = (src_height - crop_height) / 2;

  // First element of the window; the offset is counted in floats and treats
  // a source row as src_width * CHANNEL_SIZE elements.
  Npp32f* window_origin = input->p_frame() +
                          top * src_width * CHANNEL_SIZE +
                          left * CHANNEL_SIZE;

  const NppStatus status = nppiCopy_32f_C3R(window_origin,
                                            input->step(),
                                            cropped->p_frame(),
                                            cropped->step(),
                                            cropped->nppi_size());
  verify_npp_ret("nppiCopy_32f_C3R", status);
  return cropped;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/div.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Stores the divisor applied to every element of a frame.
Div::Div(float value) : _divisor(value) {}
// Divides the frame buffer of `input` by _divisor via nppsDivC_32f_I and
// returns the same context. Throws std::runtime_error (through
// verify_npp_ret) if the NPP call reports a failure.
std::shared_ptr<OpContext> Div::operator()(std::shared_ptr<OpContext> input) {
  Npp32f* const frame = input->p_frame();
  const NppStatus status = nppsDivC_32f_I(_divisor, frame, input->length());
  verify_npp_ret("nppsDivC_32f_I", status);
  return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <pybind11/numpy.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/image_io.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Uploads a height x width x CHANNEL_SIZE float image from a numpy array into
// a newly allocated device frame and returns its context.
//
// Throws std::runtime_error when the array is not float, is not
// three-dimensional, does not have CHANNEL_SIZE channels, cannot be converted
// to a dense array, or when cudaMemcpy fails.
std::shared_ptr<OpContext> Image2Gpubuffer::operator()(
    pybind11::array_t<float> input) {
  pybind11::buffer_info buf = input.request();
  if (buf.format != pybind11::format_descriptor<float>::format()) {
    throw std::runtime_error("Incompatible format: expected a float numpy!");
  }
  if (buf.ndim != 3) {
    throw std::runtime_error("Number of dimensions must be three");
  }
  if (buf.shape[2] != CHANNEL_SIZE) {
    throw std::runtime_error("Number of channels must be three");
  }

  // array_t<float> does not guarantee C-contiguous storage: a sliced or
  // transposed view keeps its strides. The flat memcpy below is only valid
  // for a dense row-major buffer, so request one explicitly. No copy is made
  // when the input is already dense.
  auto dense = pybind11::array_t<
      float,
      pybind11::array::c_style | pybind11::array::forcecast>::ensure(input);
  if (!dense) {
    throw std::runtime_error(
        "Failed to convert input to a contiguous float numpy!");
  }
  pybind11::buffer_info dense_buf = dense.request();

  auto result =
      std::make_shared<OpContext>(dense_buf.shape[0], dense_buf.shape[1]);
  auto ret = cudaMemcpy(result->p_frame(),
                        static_cast<float*>(dense_buf.ptr),
                        result->size(),
                        cudaMemcpyHostToDevice);
  verify_cuda_ret("cudaMemcpy", ret);
  return result;
}
// Downloads the device frame held by `input` into a new numpy array of shape
// (height, width, CHANNEL_SIZE) and returns it. Throws std::runtime_error
// (through verify_cuda_ret) if cudaMemcpy fails.
pybind11::array_t<float> Gpubuffer2Image::operator()(
    std::shared_ptr<OpContext> input) {
  const int rows = input->nppi_size().height;
  const int cols = input->nppi_size().width;
  const int channels = static_cast<int>(CHANNEL_SIZE);

  auto host_image = pybind11::array_t<float>({rows, cols, channels});
  pybind11::buffer_info host_buf = host_image.request();

  const auto status = cudaMemcpy(static_cast<float*>(host_buf.ptr),
                                 input->p_frame(),
                                 input->size(),
                                 cudaMemcpyDeviceToHost);
  verify_cuda_ret("cudaMemcpy", status);
  return host_image;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/normalize.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Stores the per-channel mean and standard deviation used by operator().
// Both vectors must hold exactly CHANNEL_SIZE values, otherwise
// std::runtime_error is thrown (mean is checked first).
// `channel_first` is only recorded here.
Normalize::Normalize(const std::vector<float> &mean,
                     const std::vector<float> &std,
                     bool channel_first) {
  const bool mean_ok = mean.size() == CHANNEL_SIZE;
  if (!mean_ok) {
    throw std::runtime_error("size of mean must be three");
  }
  const bool std_ok = std.size() == CHANNEL_SIZE;
  if (!std_ok) {
    throw std::runtime_error("size of std must be three");
  }

  size_t channel = 0;
  while (channel < CHANNEL_SIZE) {
    _mean[channel] = mean[channel];
    _std[channel] = std[channel];
    ++channel;
  }
  _channel_first = channel_first;
}
// Normalizes the frame of `input`: subtracts _mean per channel, then divides
// by _std per channel, and returns the same context. Either NPP failure is
// reported as std::runtime_error through verify_npp_ret.
std::shared_ptr<OpContext> Normalize::operator()(
    std::shared_ptr<OpContext> input) {
  // Step 1: x -= mean
  const NppStatus sub_status = nppiSubC_32f_C3IR(
      _mean, input->p_frame(), input->step(), input->nppi_size());
  verify_npp_ret("nppiSubC_32f_C3IR", sub_status);

  // Step 2: x /= std
  const NppStatus div_status = nppiDivC_32f_C3IR(
      _std, input->p_frame(), input->step(), input->nppi_size());
  verify_npp_ret("nppiDivC_32f_C3IR", div_status);

  return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/preprocess/hwvideoframe/include/resize.h"
#include <math.h>
#include <sstream>
#include "core/preprocess/hwvideoframe/include/utils.h"
// Resizes `input` into a newly allocated context and returns it.
//
// When _size is -1 the output is _target_size[0] x _target_size[1]
// (width x height), each clamped to _max_size. Otherwise the shorter side is
// scaled to _size while keeping the aspect ratio; if the longer side would
// then exceed _max_size, the scale is reduced so the longer side equals
// _max_size.
std::shared_ptr<OpContext> Resize::operator()(
    std::shared_ptr<OpContext> input) {
  const int src_width = input->nppi_size().width;
  const int src_height = input->nppi_size().height;

  int dst_width = 0;
  int dst_height = 0;
  if (_size != -1) {
    const int longer_side = std::max(src_height, src_width);
    float scale =
        static_cast<float>(_size) / std::min(src_height, src_width);
    // Cap the scale so the longer side does not exceed _max_size.
    if (round(scale * longer_side) > _max_size) {
      scale = static_cast<float>(_max_size) / static_cast<float>(longer_side);
    }
    dst_width = static_cast<int>(round(src_width * scale));
    dst_height = static_cast<int>(round(src_height * scale));
  } else {
    // Explicit target size requested.
    dst_width = std::min(_target_size[0], _max_size);
    dst_height = std::min(_target_size[1], _max_size);
  }

  auto resized = std::make_shared<OpContext>(dst_height, dst_width);
  const auto status = resize_linear(input->p_frame(),
                                    resized->p_frame(),
                                    src_width,
                                    src_height,
                                    resized->nppi_size().width,
                                    resized->nppi_size().height,
                                    CHANNEL_SIZE,
                                    _use_fixed_point);
  verify_cuda_ret("resize_linear", status);
  return resized;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <sstream>
#include "core/preprocess/hwvideoframe/include/resize.h"
#include "core/preprocess/hwvideoframe/include/resize_by_factor.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Resizes `input` so that its longer side does not exceed _max_side_len and
// both sides are multiples of _factor, then returns the resized frame in a
// newly allocated context.
//
// Each side is first scaled by the common ratio, then snapped:
//   - already a multiple of _factor      -> kept;
//   - at most one whole _factor fits     -> _factor;
//   - otherwise                          -> (side / _factor - 1) * _factor.
//
// Returns an empty pointer when _factor is not positive or when either
// resulting side is not positive. Throws std::runtime_error (through
// verify_cuda_ret) if resize_linear fails.
std::shared_ptr<OpContext> ResizeByFactor::operator()(
    std::shared_ptr<OpContext> input) {
  // A non-positive factor would make the modulo/division below undefined;
  // report it like any other unusable target size.
  if (_factor <= 0) {
    return nullptr;
  }
  const int factor = _factor;

  const int src_width = input->nppi_size().width;
  const int src_height = input->nppi_size().height;
  const int longer_side = src_width > src_height ? src_width : src_height;

  // Shrink only when the longer side exceeds the limit. The division must be
  // done in floating point: an integer division here always yields 0 because
  // _max_side_len < longer_side in this branch.
  float ratio = 1.0f;
  if (longer_side > _max_side_len) {
    ratio = static_cast<float>(_max_side_len) /
            static_cast<float>(longer_side);
  }

  // Snaps a side length to a multiple of `factor` using the rules above.
  auto snap_to_factor = [factor](int length) -> int {
    if (length % factor == 0) {
      return length;
    }
    if (length / factor <= 1) {
      return factor;
    }
    return (length / factor - 1) * factor;
  };

  const int resized_width =
      snap_to_factor(static_cast<int>(src_width * ratio));
  const int resized_height =
      snap_to_factor(static_cast<int>(src_height * ratio));
  if (resized_width <= 0 || resized_height <= 0) {
    return nullptr;
  }

  auto output = std::make_shared<OpContext>(resized_height, resized_width);
  auto ret = resize_linear(input->p_frame(),
                           output->p_frame(),
                           src_width,
                           src_height,
                           output->nppi_size().width,
                           output->nppi_size().height,
                           CHANNEL_SIZE,
                           _use_fixed_point);
  verify_cuda_ret("resize_linear", ret);
  return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include "core/preprocess/hwvideoframe/include/rgb_swap.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Destination channel order handed to nppiSwapChannels_32f_C3IR: channels 0
// and 2 trade places, channel 1 stays.
const int SwapChannel::_ORDER[CHANNEL_SIZE] = {2, 1, 0};

// Reorders the channels of the frame held by `input` according to _ORDER and
// returns the same context. Throws std::runtime_error (through
// verify_npp_ret) if the NPP call fails.
std::shared_ptr<OpContext> SwapChannel::operator()(
    std::shared_ptr<OpContext> input) {
  const NppStatus status = nppiSwapChannels_32f_C3IR(
      input->p_frame(), input->step(), input->nppi_size(), _ORDER);
  verify_npp_ret("nppiSwapChannels_32f_C3IR", status);
  return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/sub.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
// Uses the same subtrahend for every channel.
Sub::Sub(float subtractor) {
  for (Npp32f &slot : _subtractors) {
    slot = subtractor;
  }
}
// Uses one subtrahend per channel. Throws std::runtime_error unless
// `subtractors` holds exactly CHANNEL_SIZE values.
Sub::Sub(const std::vector<float> &subtractors) {
  const bool one_per_channel = subtractors.size() == CHANNEL_SIZE;
  if (!one_per_channel) {
    throw std::runtime_error("size of subtractors must be three");
  }
  size_t channel = 0;
  for (float value : subtractors) {
    _subtractors[channel++] = value;
  }
}
// Subtracts the stored per-channel constants from the frame of `input` via
// nppiSubC_32f_C3IR and returns the same context. Throws std::runtime_error
// (through verify_npp_ret) if the NPP call fails.
std::shared_ptr<OpContext> Sub::operator()(std::shared_ptr<OpContext> input) {
  Npp32f* const frame = input->p_frame();
  const NppStatus status = nppiSubC_32f_C3IR(
      _subtractors, frame, input->step(), input->nppi_size());
  verify_npp_ret("nppiSubC_32f_C3IR", status);
  return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <sstream>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/utils.h"
// Throws std::runtime_error("<function_name>, ret: <status>") when `ret` is
// not NPP_SUCCESS; returns normally otherwise.
void verify_npp_ret(const std::string& function_name, NppStatus ret) {
  if (ret == NPP_SUCCESS) {
    return;
  }
  std::ostringstream message;
  message << function_name << ", ret: " << ret;
  throw std::runtime_error(message.str());
}
// Throws std::runtime_error("<function_name>, ret: <status>") when `ret` is
// not cudaSuccess; returns normally otherwise.
void verify_cuda_ret(const std::string& function_name, cudaError_t ret) {
  if (ret == cudaSuccess) {
    return;
  }
  std::ostringstream message;
  message << function_name << ", ret: " << ret;
  throw std::runtime_error(message.str());
}
## Paddle Serving uses TensorRT
(English|[简体中文](./TENSOR_RT_CN.md))
### Background
Deploying models trained on mainstream frameworks with TensorRT, the inference tool launched by Nvidia, can greatly increase the speed of model inference — it is often at least twice as fast as the original framework — and it also takes up less device memory. Therefore, it is very useful for all users who need to deploy models to master the method of deploying deep learning models with TensorRT. Paddle Serving provides comprehensive TensorRT ecological support.
### Environment
Serving Cuda10.1 Cuda10.2 and Cuda11 versions support TensorRT.
#### Install Paddle
In [Development using Docker environment](./RUN_IN_DOCKER.md) and [Docker image list](./DOCKER_IMAGES.md), we give the development image of TensorRT. After starting a container from the image, you need to install the Paddle whl package that supports TensorRT; refer to the documentation on the home page.
```
# GPU Cuda10.2 environment please execute
pip install paddlepaddle-gpu==2.0.0
```
**Note**: If your Cuda version is not 10.2, please do not execute the above commands directly, you need to refer to [Paddle official documentation-multi-version whl package list
](https://www.paddlepaddle.org.cn/documentation/docs/en/install/Tables_en.html#multi-version-whl-package-list-release)
Select the URL link of the corresponding GPU environment and install it. For example, for Python2.7 users of Cuda 10.1, please select `cp27-cp27mu` and
`cuda10.1-cudnn7.6-trt6.0.1.5` corresponding url, copy it and execute
```
pip install https://paddle-wheel.bj.bcebos.com/with-trt/2.0.0-gpu-cuda10.1-cudnn7-mkl/paddlepaddle_gpu-2.0.0.post101-cp27-cp27mu-linux_x86_64.whl
```
Since the default `paddlepaddle-gpu==2.0.0` is built for Cuda 10.2 without TensorRT, if you need to use TensorRT with `paddlepaddle-gpu`, find `cuda10.2-cudnn8.0-trt7.1.3` in the above multi-version whl package list and download the package for your Python version.
#### Install Paddle Serving
```
# Cuda 10.2
pip install paddle-serving-server-gpu==${VERSION}.post102
# Cuda 10.1
pip install paddle-serving-server-gpu==${VERSION}.post101
# Cuda 11
pip install paddle-serving-server-gpu==${VERSION}.post11
```
### Use TensorRT
#### RPC mode
In [Serving model example](../python/examples), we have given models that can be accelerated using TensorRT, such as [Faster_RCNN model](../python/examples/detection/faster_rcnn_r50_fpn_1x_coco) under detection
We just need to run:
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_r50_fpn_1x_coco.tar
tar xf faster_rcnn_r50_fpn_1x_coco.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 --use_trt
```
The TensorRT version of the faster_rcnn model server is started
#### Local Predictor mode
In [local_predictor](../python/paddle_serving_app/local_predict.py#L52), users can explicitly specify `use_trt=True` and pass it to `load_model_config`.
Other methods are no different from other Local Predictor methods, and you need to pay attention to the compatibility of the model with TensorRT.
#### Pipeline Mode
In [Pipeline mode](./PIPELINE_SERVING.md), our [imagenet example](../python/examples/pipeline/imagenet/config.yml#L23) gives the way to set TensorRT.
## Paddle Serving 使用 TensorRT
([English](./TENSOR_RT.md)|简体中文)
### 背景
通过Nvidia推出的TensorRT工具来部署主流框架上训练的模型能够极大地提高模型推断的速度,往往相比于原本的框架能够有至少1倍以上的速度提升,同时占用的设备内存也会更少。因此对于所有需要部署模型的用户来说,掌握用TensorRT来部署深度学习模型的方法是非常有用的。Paddle Serving提供了全面的TensorRT生态支持。
### 环境
Serving 的Cuda10.1 Cuda10.2和Cuda11版本支持TensorRT。
#### 安装Paddle
在[使用Docker环境开发](./RUN_IN_DOCKER_CN.md)和[Docker镜像列表](./DOCKER_IMAGES_CN.md)当中,我们给出了TensorRT的开发镜像。使用镜像启动之后,需要安装支持TensorRT的Paddle whl包,参考首页的文档
```
# GPU Cuda10.2环境请执行
pip install paddlepaddle-gpu==2.0.0
```
**注意**: 如果您的Cuda版本不是10.2,请勿直接执行上述命令,需要参考[Paddle官方文档-多版本whl包列表
](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release)
选择相应的GPU环境的url链接并进行安装,例如Cuda 10.1的Python2.7用户,请选择表格当中的`cp27-cp27mu`
`cuda10.1-cudnn7.6-trt6.0.1.5`对应的url,复制下来并执行
```
pip install https://paddle-wheel.bj.bcebos.com/with-trt/2.0.0-gpu-cuda10.1-cudnn7-mkl/paddlepaddle_gpu-2.0.0.post101-cp27-cp27mu-linux_x86_64.whl
```
由于默认的`paddlepaddle-gpu==2.0.0`是Cuda 10.2,并没有联编TensorRT,因此如果需要在`paddlepaddle-gpu`上使用TensorRT,需要在上述多版本whl包列表当中,找到`cuda10.2-cudnn8.0-trt7.1.3`,下载对应的Python版本。
#### 安装Paddle Serving
```
# Cuda 10.2
pip install paddle-serving-server-gpu==${VERSION}.post102
# Cuda 10.1
pip install paddle-serving-server-gpu==${VERSION}.post101
# Cuda 11
pip install paddle-serving-server-gpu==${VERSION}.post11
```
### 使用TensorRT
#### RPC模式
[Serving模型示例](../python/examples)当中,我们有给出可以使用TensorRT加速的模型,例如detection下的[Faster_RCNN模型](../python/examples/detection/faster_rcnn_r50_fpn_1x_coco)
我们只需
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_r50_fpn_1x_coco.tar
tar xf faster_rcnn_r50_fpn_1x_coco.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 --use_trt
```
TensorRT版本的faster_rcnn模型服务端就启动了
#### Local Predictor模式
在[local_predictor](../python/paddle_serving_app/local_predict.py#L52)当中,用户可以显式指定`use_trt=True`传入到`load_model_config`当中。
其他方式和其他Local Predictor使用方法没有区别,需要注意模型对TensorRT的兼容性。
#### Pipeline模式
[Pipeline模式](./PIPELINE_SERVING_CN.md)当中,我们的[imagenet例子](../python/examples/pipeline/imagenet/config.yml#L23)给出了设置TensorRT的方式。
...@@ -33,6 +33,7 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz ...@@ -33,6 +33,7 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model
mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client
``` ```
If your model is bert_chinese_L-12_H-768_A-12_model, replace the 'bert_seq128_model' field in the following commands with 'bert_chinese_L-12_H-768_A-12_model', and replace 'bert_seq128_client' with 'bert_chinese_L-12_H-768_A-12_client'.
### Getting Dict and Sample Dataset ### Getting Dict and Sample Dataset
......
...@@ -28,6 +28,9 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz ...@@ -28,6 +28,9 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model
mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client
``` ```
若使用bert_chinese_L-12_H-768_A-12_model模型,将下面命令中的bert_seq128_model字段替换为bert_chinese_L-12_H-768_A-12_model,bert_seq128_client字段替换为bert_chinese_L-12_H-768_A-12_client.
### 获取词典和样例数据 ### 获取词典和样例数据
......
...@@ -14,7 +14,10 @@ Paddle Detection provides a large number of [Model Zoo](https://github.com/Paddl ...@@ -14,7 +14,10 @@ Paddle Detection provides a large number of [Model Zoo](https://github.com/Paddl
Several examples of PaddleDetection models used in Serving are given in this folder Several examples of PaddleDetection models used in Serving are given in this folder
All examples support TensorRT. All examples support TensorRT.
-[Faster RCNN](./faster_rcnn_r50_fpn_1x_coco) - [Faster RCNN](./faster_rcnn_r50_fpn_1x_coco)
-[PPYOLO](./ppyolo_r50vd_dcn_1x_coco) - [PPYOLO](./ppyolo_r50vd_dcn_1x_coco)
-[TTFNet](./ttfnet_darknet53_1x_coco) - [TTFNet](./ttfnet_darknet53_1x_coco)
-[YOLOv3](./yolov3_darknet53_270e_coco) - [YOLOv3](./yolov3_darknet53_270e_coco)
- [HRNet](./faster_rcnn_hrnetv2p_w18_1x)
- [Fcos](./fcos_dcn_r50_fpn_1x_coco)
- [SSD](./ssd_vgg16_300_240e_voc/)
...@@ -19,4 +19,6 @@ Paddle Detection提供了大量的[模型库](https://github.com/PaddlePaddle/Pa ...@@ -19,4 +19,6 @@ Paddle Detection提供了大量的[模型库](https://github.com/PaddlePaddle/Pa
- [PPYOLO](./ppyolo_r50vd_dcn_1x_coco) - [PPYOLO](./ppyolo_r50vd_dcn_1x_coco)
- [TTFNet](./ttfnet_darknet53_1x_coco) - [TTFNet](./ttfnet_darknet53_1x_coco)
- [YOLOv3](./yolov3_darknet53_270e_coco) - [YOLOv3](./yolov3_darknet53_270e_coco)
- [HRNet](./faster_rcnn_hrnetv2p_w18_1x)
- [Fcos](./fcos_dcn_r50_fpn_1x_coco)
- [SSD](./ssd_vgg16_300_240e_voc/)
# Faster RCNN HRNet model on Paddle Serving
([简体中文](./README_CN.md)|English)
### Get The Faster RCNN HRNet Model
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_hrnetv2p_w18_1x.tar
```
### Start the service
```
tar xf faster_rcnn_hrnetv2p_w18_1x.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT. If you want faster inference, please use `--use_trt`.
### Prediction
```
python test_client.py 000000570688.jpg
```
# 使用Paddle Serving部署Faster RCNN HRNet模型
(简体中文|[English](./README.md))
## 获得Faster RCNN HRNet模型
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_hrnetv2p_w18_1x.tar
```
### 启动服务
```
tar xf faster_rcnn_hrnetv2p_w18_1x.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
```
该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。
### 执行预测
```
python test_client.py 000000570688.jpg
```
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import sys
import numpy as np
# Preprocessing applied to the image file named on the command line. The
# stages run in the order listed.
preprocess = Sequential([
    File2Image(),
    BGR2RGB(),
    Div(255.0),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
    Resize(640, 640),
    Transpose((2, 0, 1)),
])

# Constructed here but not applied to the result below.
postprocess = RCNNPostprocess("label_list.txt", "output")

# Connect to the serving endpoint started as described in the README.
client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494'])

im = preprocess(sys.argv[1])

# "im_info" and "im_shape" both carry the two trailing dimensions of the
# preprocessed image followed by 1.0.
feed = {
    "image": im,
    "im_info": np.array(list(im.shape[1:]) + [1.0]),
    "im_shape": np.array(list(im.shape[1:]) + [1.0]),
}
fetch_map = client.predict(
    feed=feed, fetch=["multiclass_nms_0.tmp_0"], batch=False)
print(fetch_map)
# FCOS model on Paddle Serving
([简体中文](./README_CN.md)|English)
### Get Model
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/fcos_dcn_r50_fpn_1x_coco.tar
```
### Start the service
```
tar xf fcos_dcn_r50_fpn_1x_coco.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT. If you want faster inference, please use `--use_trt`.
### Perform prediction
```
python test_client.py 000000570688.jpg
```
# 使用Paddle Serving部署FCOS模型
(简体中文|[English](./README.md))
## 获得模型
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/fcos_dcn_r50_fpn_1x_coco.tar
```
### 启动服务
```
tar xf fcos_dcn_r50_fpn_1x_coco.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
```
该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。
### 执行预测
```
python test_client.py 000000570688.jpg
```
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import sys
import numpy as np
# Preprocessing applied to the image file named on the command line. The
# stages run in the order listed.
preprocess = Sequential([
    File2Image(),
    BGR2RGB(),
    Div(255.0),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
    Resize(640, 640),
    Transpose((2, 0, 1)),
])

# Constructed here but not applied to the result below.
postprocess = RCNNPostprocess("label_list.txt", "output")

# Connect to the serving endpoint started as described in the README.
client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494'])

im = preprocess(sys.argv[1])

# A fixed scale factor of 1.0 is sent for both axes.
feed = {
    "image": im,
    "scale_factor": np.array([1.0, 1.0]).reshape(-1),
}
fetch_map = client.predict(
    feed=feed, fetch=["save_infer_model/scale_0.tmp_1"], batch=False)
print(fetch_map)
# SSD model on Paddle Serving
([简体中文](./README_CN.md)|English)
### Get Model
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ssd_vgg16_300_240e_voc.tar
```
### Start the service
```
tar xf ssd_vgg16_300_240e_voc.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT. If you want faster inference, please use `--use_trt`.
### Perform prediction
```
python test_client.py 000000570688.jpg
```
# 使用Paddle Serving部署SSD模型
(简体中文|[English](./README.md))
## 获得模型
```
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ssd_vgg16_300_240e_voc.tar
```
### 启动服务
```
tar xf ssd_vgg16_300_240e_voc.tar
python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
```
该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。
### 执行预测
```
python test_client.py 000000570688.jpg
```
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import *
import sys
import numpy as np
# Preprocessing pipeline applied to the image path given on the command line.
# The operators run in the order listed: load the file, swap BGR to RGB,
# normalize with the given per-channel constants, resize to 512x512, then
# transpose the axes with (2, 0, 1).
# NOTE(review): the two lists passed to Normalize look like per-channel
# mean/std and the trailing False like a scaling switch -- confirm against
# paddle_serving_app.reader.Normalize.
preprocess = Sequential([
    File2Image(),
    BGR2RGB(),
    Normalize([123.675, 116.28, 103.53], [58.395, 57.12, 57.375], False),
    Resize((512, 512)),
    Transpose((2, 0, 1)),
])

# Built here but never applied to the prediction result below.
postprocess = RCNNPostprocess("label_list.txt", "output")

# Load the client-side model config and connect to the local serving endpoint.
client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9494'])

# sys.argv[1] is the image path; the script raises IndexError if it is missing.
im = preprocess(sys.argv[1])

# Send one un-batched request and print the raw fetch result.
feed = {
    "image": im,
    "scale_factor": np.array([1.0, 1.0]).reshape(-1),
}
fetch = ["save_infer_model/scale_0.tmp_1"]
fetch_map = client.predict(feed=feed, fetch=fetch, batch=False)
print(fetch_map)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import unittest
import sys
import numpy as np
import cv2
from paddle_serving_app.reader import Sequential, Resize, File2Image
import libgpupreprocess as pp
import libhwextract
class TestOperators(unittest.TestCase):
    """
    Unit tests for the GPU preprocessing operators exposed by the
    ``libgpupreprocess`` module (imported as ``pp``): Div, Sub, Normalize,
    CenterCrop and Resize.

    Each test wraps the operator under test in a ``Sequential`` pipeline of
    the form ``Image2Gpubuffer -> <operator> -> Gpubuffer2Image`` and checks
    the array returned by the pipeline.
    """

    def _assert_elementwise_almost_equal(self, first, second, places=5):
        """Assert that two arrays agree element by element.

        Args:
            first: array whose shape drives the iteration.
            second: array indexed with the same indices as ``first``.
            places: number of decimal places passed to ``assertAlmostEqual``.
        """
        for idx in np.ndindex(*first.shape):
            self.assertAlmostEqual(first[idx], second[idx], places)

    def test_div(self):
        """Div(value) divides every element of the image by ``value``."""
        height = 4
        width = 5
        channels = 3
        value = 255.0
        # Element i of the flattened image holds the value i.
        img = np.arange(height * width * channels).reshape(
            [height, width, channels])
        seq = Sequential(
            [pp.Image2Gpubuffer(), pp.Div(value), pp.Gpubuffer2Image()])
        result = seq(img).reshape(-1)
        for i in range(0, result.size):
            self.assertAlmostEqual(i / value, result[i], 5)

    def test_sub(self):
        """Sub accepts a scalar or a per-channel sequence of subtrahends."""
        height = 4
        width = 5
        channels = 3
        img = np.arange(height * width * channels).reshape(
            [height, width, channels])

        # Scalar subtrahend: the same value is removed from every element.
        value = 10.0
        seq = Sequential(
            [pp.Image2Gpubuffer(), pp.Sub(value), pp.Gpubuffer2Image()])
        result = seq(img).reshape(-1)
        for i in range(0, result.size):
            self.assertEqual(i - value, result[i])

        # Sequence subtrahend: values[k] is removed from channel k.
        values = (9, 4, 2)
        seq = Sequential(
            [pp.Image2Gpubuffer(), pp.Sub(values), pp.Gpubuffer2Image()])
        result = seq(img)
        for i in range(0, result.shape[0]):
            for j in range(0, result.shape[1]):
                for k in range(0, result.shape[2]):
                    self.assertEqual(result[i][j][k],
                                     img[i][j][k] - values[k])

    def test_normalize(self):
        """Normalize(mean, std) yields (pixel - mean[k]) / std[k] per channel k."""
        height = 4
        width = 5
        channels = 3
        img = np.random.rand(height, width, channels)
        mean = [5.0, 5.0, 5.0]
        std = [2.0, 2.0, 2.0]
        seq = Sequential([
            pp.Image2Gpubuffer(), pp.Normalize(mean, std), pp.Gpubuffer2Image()
        ])
        result = seq(img)
        # mean/std broadcast over the last (channel) axis of the image.
        expected = (img - np.asarray(mean)) / np.asarray(std)
        self._assert_elementwise_almost_equal(expected, result, 5)

    def test_center_crop(self):
        """CenterCrop(n) returns an n x n image and keeps the channel count."""
        height = 9
        width = 7
        channels = 3
        img = np.arange(height * width * channels).reshape(
            [height, width, channels])
        new_size = 5
        seq = Sequential([
            pp.Image2Gpubuffer(), pp.CenterCrop(new_size), pp.Gpubuffer2Image()
        ])
        result = seq(img)
        self.assertEqual(result.shape[0], new_size)
        self.assertEqual(result.shape[1], new_size)
        self.assertEqual(result.shape[2], channels)

    def test_resize(self):
        """GPU Resize must match paddle_serving_app's Resize for the same size."""
        height = 9
        width = 5
        channels = 3
        # Every channel carries the same ramp 0 .. height * width - 1.
        img = np.arange(height * width).reshape([height, width, 1]) * np.ones(
            (1, channels))

        # Size given as a single int.
        for new_size in [3, 10]:
            seq_gpu = Sequential([
                pp.Image2Gpubuffer(), pp.Resize(new_size), pp.Gpubuffer2Image()
            ])
            seq_paddle = Sequential([Resize(new_size)])
            result_gpu = seq_gpu(img)
            result_paddle = seq_paddle(img)
            self.assertEqual(result_gpu.shape, result_paddle.shape)
            self._assert_elementwise_almost_equal(result_gpu, result_paddle,
                                                  5)

        # Size given as a (width, height) sequence.
        for new_height, new_width in [(7, 3), (15, 10)]:
            seq_gpu = Sequential([
                pp.Image2Gpubuffer(), pp.Resize((new_width, new_height)),
                pp.Gpubuffer2Image()
            ])
            seq_paddle = Sequential([Resize((new_width, new_height))])
            result_gpu = seq_gpu(img)
            result_paddle = seq_paddle(img)
            self.assertEqual(result_gpu.shape, result_paddle.shape)
            self._assert_elementwise_almost_equal(result_gpu, result_paddle,
                                                  5)

    def test_resize_fixed_point(self):
        """Fixed-point Resize must reproduce the stored reference output.

        Reads ``./capture_16.bmp`` and ``./cap_resize_16.raw`` from the
        current working directory.
        """
        new_height = 256
        # Floor division keeps the width an int on Python 3 as well; a float
        # here would later be rejected by np.resize as a shape component.
        new_width = 256 * 4 // 3
        seq = Sequential([
            File2Image(), pp.Image2Gpubuffer(), pp.Resize(
                (new_width, new_height), use_fixed_point=True),
            pp.Gpubuffer2Image()
        ])
        img = seq("./capture_16.bmp")
        # Flatten the channel axis into the row so the layout matches the
        # 2-D reference array loaded below.
        img = np.resize(img, (new_height, new_width * 3))
        img_vis = np.loadtxt("./cap_resize_16.raw")
        img_resize_diff = img_vis - img
        self.assertTrue(np.all(img_resize_diff == 0))
if __name__ == '__main__':
unittest.main()
...@@ -11,3 +11,6 @@ ...@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing
from .arr2image import Arr2Image
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import cv2
import yaml
class Arr2Image(object):
    """
    Callable that decodes an encoded image held in a numpy array (the
    original note mentions JPEG) into an OpenCV image.
    """

    def __init__(self):
        pass

    def __call__(self, img_arr):
        # Decode as a color image and hand back whatever cv2.imdecode returns.
        return cv2.imdecode(img_arr, cv2.IMREAD_COLOR)

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import platform
import os import os
from setuptools import setup, Distribution, Extension from setuptools import setup, Distribution, Extension
...@@ -23,14 +24,22 @@ from setuptools import find_packages ...@@ -23,14 +24,22 @@ from setuptools import find_packages
from setuptools import setup from setuptools import setup
from paddle_serving_app.version import serving_app_version from paddle_serving_app.version import serving_app_version
from pkg_resources import DistributionNotFound, get_distribution from pkg_resources import DistributionNotFound, get_distribution
import util
max_version, mid_version, min_version = util.python_version() def python_version():
return [int(v) for v in platform.python_version().split(".")]
def find_package(pkgname):
    """Report whether ``get_distribution`` can locate ``pkgname``.

    Returns:
        True when the lookup succeeds, False when it raises
        DistributionNotFound. Any other exception propagates to the caller.
    """
    try:
        get_distribution(pkgname)
    except DistributionNotFound:
        return False
    return True
max_version, mid_version, min_version = python_version()
if '${PACK}' == 'ON': if '${PACK}' == 'ON':
copy_lib() copy_lib()
REQUIRED_PACKAGES = [ REQUIRED_PACKAGES = [
'six >= 1.10.0', 'sentencepiece<=0.1.83', 'opencv-python<=4.2.0.32', 'pillow', 'six >= 1.10.0', 'sentencepiece<=0.1.83', 'opencv-python<=4.2.0.32', 'pillow',
'pyclipper', 'shapely' 'pyclipper', 'shapely'
...@@ -43,7 +52,11 @@ packages=['paddle_serving_app', ...@@ -43,7 +52,11 @@ packages=['paddle_serving_app',
'paddle_serving_app.models', 'paddle_serving_app.models',
'paddle_serving_app.reader.pddet'] 'paddle_serving_app.reader.pddet']
package_data={} if os.path.exists('../core/preprocess/hwvideoframe/libgpupreprocess.so'):
package_data={'paddle_serving_app': ['../core/preprocess/hwvideoframe/libgpupreprocess.so'],}
else:
package_data={}
package_dir={'paddle_serving_app': package_dir={'paddle_serving_app':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app', '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app',
'paddle_serving_app.proto': 'paddle_serving_app.proto':
...@@ -55,7 +68,8 @@ package_dir={'paddle_serving_app': ...@@ -55,7 +68,8 @@ package_dir={'paddle_serving_app':
'paddle_serving_app.models': 'paddle_serving_app.models':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/models', '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/models',
'paddle_serving_app.reader.pddet': 'paddle_serving_app.reader.pddet':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',} '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',
}
setup( setup(
name='paddle-serving-app', name='paddle-serving-app',
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing
from pkg_resources import DistributionNotFound, get_distribution from pkg_resources import DistributionNotFound, get_distribution
from grpc_tools import protoc from grpc_tools import protoc
......
...@@ -932,7 +932,6 @@ function python_app_api_test(){ ...@@ -932,7 +932,6 @@ function python_app_api_test(){
cd imagenet cd imagenet
case $TYPE in case $TYPE in
CPU) CPU)
check_cmd "python test_image_reader.py"
;; ;;
GPU) GPU)
echo "no implement for cpu type" echo "no implement for cpu type"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册