diff --git a/README.md b/README.md index 76e46b968aa77dac12ce241ebf66d53eebda5462..0464cf3dc8d85da1f96d7328ac2e4c03962c09cf 100644 --- a/README.md +++ b/README.md @@ -85,14 +85,14 @@ We **highly recommend** you to **run Paddle Serving in Docker**, please visit [R ``` # Run CPU Docker docker pull registry.baidubce.com/paddlepaddle/serving:0.5.0-devel -docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-devel +docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-devel bash docker exec -it test bash git clone https://github.com/PaddlePaddle/Serving ``` ``` # Run GPU Docker nvidia-docker pull registry.baidubce.com/paddlepaddle/serving:0.5.0-cuda10.2-cudnn8-devel -nvidia-docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-cuda10.2-cudnn8-devel +nvidia-docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-cuda10.2-cudnn8-devel bash nvidia-docker exec -it test bash git clone https://github.com/PaddlePaddle/Serving ``` diff --git a/README_CN.md b/README_CN.md index 2cae7b525833f9c60411ea6c7f48f3860e22a10b..ad8479ee6026eb42b128bd0d12395c4546b2c050 100644 --- a/README_CN.md +++ b/README_CN.md @@ -86,14 +86,14 @@ Paddle Serving开发者为您提供了简单易用的[AIStudio教程-Paddle Serv ``` # 启动 CPU Docker docker pull registry.baidubce.com/paddlepaddle/serving:0.5.0-devel -docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-devel +docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-devel bash docker exec -it test bash git clone https://github.com/PaddlePaddle/Serving ``` ``` # 启动 GPU Docker nvidia-docker pull registry.baidubce.com/paddlepaddle/serving:0.5.0-cuda10.2-cudnn8-devel -nvidia-docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-cuda10.2-cudnn8-devel +nvidia-docker run -p 9292:9292 --name test -dit registry.baidubce.com/paddlepaddle/serving:0.5.0-cuda10.2-cudnn8-devel bash nvidia-docker exec -it test bash git clone https://github.com/PaddlePaddle/Serving ``` diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 56296b53319fb185c772ffa10e8b31c8203862fb..a174bbe1a35064f99ba73c10dea834ae7bc17e4e 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -25,6 +25,7 @@ endif() if (APP) add_subdirectory(configure) +add_subdirectory(preprocess) endif() diff --git a/core/general-client/CMakeLists.txt b/core/general-client/CMakeLists.txt index ddacb8d53d141e242fe6222a837ec8997608382b..d6079317a75d3f45b61920836e6695bd6b31d951 100644 --- a/core/general-client/CMakeLists.txt +++ b/core/general-client/CMakeLists.txt @@ -1,6 +1,5 @@ if(CLIENT) add_subdirectory(pybind11) pybind11_add_module(serving_client src/general_model.cpp src/pybind_general_model.cpp) -add_dependencies(serving_client sdk_cpp) target_link_libraries(serving_client PRIVATE -Wl,--whole-archive utils sdk-cpp pybind python -Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz -Wl,-rpath,'$ORIGIN'/lib) endif() diff --git a/core/preprocess/CMakeLists.txt b/core/preprocess/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..961a09c4598482af98b7102f1ae7999cab4a18e4 --- /dev/null +++ b/core/preprocess/CMakeLists.txt @@ -0,0 +1,12 @@ +if(NOT WITH_GPU) + message("GPU preprocess will not be compiled.") + return() +endif() + +message(STATUS "CUDA detected: " ${CUDA_VERSION}) +if (${CUDA_VERSION} LESS 10.0) + message("CUDA version should be 10.0.") + return() +elseif (${CUDA_VERSION} LESS 10.1) # CUDA 10.0 + add_subdirectory(hwvideoframe) +endif() diff --git a/core/preprocess/hwvideoframe/CMakeLists.txt b/core/preprocess/hwvideoframe/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f57cbf3b5673e98668cedf30d8368440edd242e --- /dev/null +++ b/core/preprocess/hwvideoframe/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.2) +project(gpupreprocess) +include(cuda) +include(configure) + +#C flags +set(CMAKE_C_FLAGS " -g -pipe -W -Wall -fPIC") + +#C++ flags. +set(CMAKE_CXX_FLAGS " -g -pipe -W -Wall -fPIC -std=c++11") + +add_subdirectory(cuda) +set(PYTHON_SO {PYTHON_LIBRARY}) +set(EXTRA_LIBS ${EXTRA_LIBS} gpu) +file(GLOB SOURCE_FILES pybind/*.cpp src/*.cpp) + +include_directories("./include") +include_directories(${CUDA_INCLUDE_DIRS}) +include_directories(/usr${PYTHON_INCLUDE_DIR}) + +link_directories("-L${CUDA_TOOLKIT_ROOT_DIR}/lib64 -lcudart -lnppidei_static -lnppial_static -lnpps_static -lnppc_static -lculibos") +link_directories(${PYTHON_SO}) + +#.so +add_library(gpupreprocess SHARED ${SOURCE_FILES}) +target_link_libraries(gpupreprocess ${EXTRA_LIBS}) +target_link_libraries(gpupreprocess ${CUDA_LIBRARIES}) diff --git a/core/preprocess/hwvideoframe/README.md b/core/preprocess/hwvideoframe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5d561c75fc6b891f9752a303630b3914dce7998b --- /dev/null +++ b/core/preprocess/hwvideoframe/README.md @@ -0,0 +1,61 @@ +# hwvideoframe + +Hwvideoframe is a CV preprocessing library based on cuda. The project uses GPU for image preprocessing operations. It speeds up the processing speed while increasing the utilization rate of the GPU. + +## Preprocess API + +Hwvideoframe provides a variety of data preprocessing methods for photo preprocess: + +- class Image2Gpubuffer + + - `__call__(img)` + - img(np.array):Image data. + +- class Gpubuffer2Image + + - `__call__(img)` + - img(np.array):Image data. + +- class Div + + - `__init__(value)` + - value(float):Constant value to be divided. + - `__call__(img)` + - img(np.array):Image data. + +- class Sub + + - `__init__(subtractor)` + - subtractor(list/float):Three 32-bit floating point channel image subtract constant. When the input is a list type, length of list must be three. + - `__call__(img)` + - img(np.array):Image data in (C,H,W) channels. + +- class Normalize + + - `__init__(mean,std)` + - mean(list):Mean. Length of list must be three. + - std(list):Variance. Length of list must be three. + - `__call__(img)` + - img(np.array):Image data in (C,H,W) channels. + +- class CenterCrop + + - `__init__(size)` + - size(int):Crops the given Image at the center while the size must not bigger than any inputs' height and width. + - `__call__(img)` + - img(np.array):Image data in (C,H,W) channels. + +- class Resize + + - `__init__(size, max_size=2147483647, interpolation=None)` + - size(list/int):The expected image size, when the input is a list type, it needs to contain the expected length and width. When the input is int type, the short side will be set to the length of size, and the long side will be scaled proportionally. + - `__call__(img)` + - img(numpy array):Image data in (C,H,W) channels. + +## Quick start + +[After compiling from code](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE.md),this project will be stored in reader。 + +## How to Test + +Test file:Serving/python/paddle_serving_app/reader/test_preprocess.py diff --git a/core/preprocess/hwvideoframe/README_ZN.md b/core/preprocess/hwvideoframe/README_ZN.md new file mode 100644 index 0000000000000000000000000000000000000000..aed080b8c1762b0bbc61cb3ed075b17308c826f9 --- /dev/null +++ b/core/preprocess/hwvideoframe/README_ZN.md @@ -0,0 +1,61 @@ +# hwvideoframe + +hwvideoframe是一个基于cuda的图片预处理库。项目利用GPU进行图片预处理操作,在加快处理速度的同时,提高GPU的使用率。 + +## 项目结构 + +hwvideoframe目前提供以下图片预处理功能: + +- class Image2Gpubuffer + + - `__call__(img)` + - img(np.array):输入图像。 + +- class Gpubuffer2Image + + - `__call__(img)` + - img(np.array):输入图像。 + +- class Div + + - `__init__(value)` + - value(float):根据固定值切割图像。 + - `__call__(img)` + - img(np.array):输入图像。 + +- class Sub + + - `__init__(subtractor)` + - subtractor(list/float):list的长度必须为3。 + - `__call__(img)` + - img(np.array):(C,H,W)排列的图像数据。 + +- class Normalize + + - `__init__(mean,std)` + - mean(list):均值。 list长度必须为3。 + - std(list):方差。 list长度必须为3。 + - `__call__(img)` + - img(np.array):(C,H,W)排列的图像数据。 + +- class CenterCrop + + - `__init__(size)` + - size(int):预期的裁剪后的大小,list类型时需要包含预期的长和宽,int类型时会返回边长为size的正方形图片。size不能大于原始图片大小。 + - `__call__(img)` + - img(np.array):(C,H,W)排列的图像数据 + +- class Resize + + - `__init__(size, max_size=2147483647, interpolation=None)` + - size(list/int):预期的图像大小,短边会设置为size的长度,长边按比例缩放. + - `__call__(img)` + - img(numpy array):(C,H,W)排列的图像数据 + +## 快速开始 + +按照Paddle Serving文档编译,编译结果在reader中。 + +## 测试 + +测试文件:Serving/python/paddle_serving_app/reader/test_preprocess.py diff --git a/core/preprocess/hwvideoframe/cuda/CMakeLists.txt b/core/preprocess/hwvideoframe/cuda/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..dddfea97ab56f08705432eaa7c60955cd2e0cdb2 --- /dev/null +++ b/core/preprocess/hwvideoframe/cuda/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.2) +project(gpu) +FIND_PACKAGE(CUDA ${CUDA_VERSION} REQUIRED) +SET(CUDA_TARGET_INCLUDE ${CUDA_TOOLKIT_ROOT_DIR}-${CUDA_VERSION}/targets/${CMAKE_HOST_SYSTEM_PROCESSOR}-${LOWER_SYSTEM_NAME}/include) + +file(GLOB_RECURSE CURRENT_HEADERS *.h *.hpp *.cuh) +file(GLOB CURRENT_SOURCES *.cpp *.cu) + +source_group("Include" FILES ${CURRENT_HEADERS}) +source_group("Source" FILES ${CURRENT_SOURCES}) + +set(CMAKE_CUDA_FLAGS "-ccbin /opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11") +set(CUDA_NVCC_FLAGS "-L/opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11") + +include_directories(${CUDA_INCLUDE_DIRS}) + +cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES}) +target_link_libraries(gpu ${CUDA_LIBS}) diff --git a/core/preprocess/hwvideoframe/cuda/resize.cu b/core/preprocess/hwvideoframe/cuda/resize.cu new file mode 100644 index 0000000000000000000000000000000000000000..a20e2aa3288ddc30ccd76998464bc64c4f4a9ea1 --- /dev/null +++ b/core/preprocess/hwvideoframe/cuda/resize.cu @@ -0,0 +1,284 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "./cuda_runtime.h" + +#define clip(x, a, b) x >= a ? (x < b ? x : b - 1) : a; + +const int INTER_RESIZE_COEF_BITS = 11; +const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS; + +__global__ void resizeCudaKernel(const float* input, + float* output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels) { + // 2D Index of current thread + const int dx = blockIdx.x * blockDim.x + threadIdx.x; + const int dy = blockIdx.y * blockDim.y + threadIdx.y; + + if ((dx < outputWidth) && (dy < outputHeight)) { + if (inputChannels == 1) { // grayscale image + // TODO(Zelda): support grayscale + } else if (inputChannels == 3) { // RGB image + double scale_x = static_cast(inputWidth / outputWidth); + double scale_y = static_cast(inputHeight / outputHeight); + + int xmax = outputWidth; + + float fx = static_cast((dx + 0.5) * scale_x - 0.5); + int sx = floorf(fx); + fx = fx - sx; + + int isx1 = sx; + if (isx1 < 0) { + fx = 0.0; + isx1 = 0; + } + if (isx1 >= (inputWidth - 1)) { + xmax = ::min(xmax, dx); + fx = 0; + isx1 = inputWidth - 1; + } + + float2 cbufx; + cbufx.x = (1.f - fx); + cbufx.y = fx; + + float fy = static_cast((dy + 0.5) * scale_y - 0.5); + int sy = floorf(fy); + fy = fy - sy; + + int isy1 = clip(sy + 0, 0, inputHeight); + int isy2 = clip(sy + 1, 0, inputHeight); + + float2 cbufy; + cbufy.x = (1.f - fy); + cbufy.y = fy; + + int isx2 = isx1 + 1; + + float3 d0; + + float3 s11 = + make_float3(input[(isy1 * inputWidth + isx1) * inputChannels + 0], + input[(isy1 * inputWidth + isx1) * inputChannels + 1], + input[(isy1 * inputWidth + isx1) * inputChannels + 2]); + float3 s12 = + make_float3(input[(isy1 * inputWidth + isx2) * inputChannels + 0], + input[(isy1 * inputWidth + isx2) * inputChannels + 1], + input[(isy1 * inputWidth + isx2) * inputChannels + 2]); + float3 s21 = + make_float3(input[(isy2 * inputWidth + isx1) * inputChannels + 0], + input[(isy2 * inputWidth + isx1) * inputChannels + 1], + input[(isy2 * inputWidth + isx1) * inputChannels + 2]); + float3 s22 = + make_float3(input[(isy2 * inputWidth + isx2) * inputChannels + 0], + input[(isy2 * inputWidth + isx2) * inputChannels + 1], + input[(isy2 * inputWidth + isx2) * inputChannels + 2]); + + float h_rst00, h_rst01; + // B + if (dx > xmax - 1) { + h_rst00 = s11.x; + h_rst01 = s21.x; + } else { + h_rst00 = s11.x * cbufx.x + s12.x * cbufx.y; + h_rst01 = s21.x * cbufx.x + s22.x * cbufx.y; + } + d0.x = h_rst00 * cbufy.x + h_rst01 * cbufy.y; + + // G + if (dx > xmax - 1) { + h_rst00 = s11.y; + h_rst01 = s21.y; + } else { + h_rst00 = s11.y * cbufx.x + s12.y * cbufx.y; + h_rst01 = s21.y * cbufx.x + s22.y * cbufx.y; + } + d0.y = h_rst00 * cbufy.x + h_rst01 * cbufy.y; + // R + if (dx > xmax - 1) { + h_rst00 = s11.z; + h_rst01 = s21.z; + } else { + h_rst00 = s11.z * cbufx.x + s12.z * cbufx.y; + h_rst01 = s21.z * cbufx.x + s22.z * cbufx.y; + } + d0.z = h_rst00 * cbufy.x + h_rst01 * cbufy.y; + + output[(dy * outputWidth + dx) * 3 + 0] = (d0.x); // R + output[(dy * outputWidth + dx) * 3 + 1] = (d0.y); // G + output[(dy * outputWidth + dx) * 3 + 2] = (d0.z); // B + } else { + // TODO(Zelda): support alpha channel + } + } +} + +__global__ void resizeCudaKernel_fixpt(const float* input, + float* output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels) { + // 2D Index of current thread + const int dx = blockIdx.x * blockDim.x + threadIdx.x; + const int dy = blockIdx.y * blockDim.y + threadIdx.y; + + if ((dx < outputWidth) && (dy < outputHeight)) { + if (inputChannels == 1) { // grayscale image + // TODO(Zelda): support grayscale + } else if (inputChannels == 3) { // RGB image + double scale_x = static_cast(inputWidth / outputWidth); + double scale_y = static_cast(inputHeight / outputHeight); + + int xmax = outputWidth; + + float fx = static_cast((dx + 0.5) * scale_x - 0.5); + int sx = floorf(fx); + fx = fx - sx; + + int isx1 = sx; + if (isx1 < 0) { + fx = 0.0; + isx1 = 0; + } + if (isx1 >= (inputWidth - 1)) { + xmax = ::min(xmax, dx); + fx = 0; + isx1 = inputWidth - 1; + } + + short2 cbufx; + cbufx.x = lrintf((1.f - fx) * INTER_RESIZE_COEF_SCALE); + cbufx.y = lrintf(fx * INTER_RESIZE_COEF_SCALE); + + float fy = static_cast((dy + 0.5) * scale_y - 0.5); + int sy = floorf(fy); + fy = fy - sy; + + int isy1 = clip(sy + 0, 0, inputHeight); + int isy2 = clip(sy + 1, 0, inputHeight); + + short2 cbufy; + cbufy.x = lrintf((1.f - fy) * INTER_RESIZE_COEF_SCALE); + cbufy.y = lrintf(fy * INTER_RESIZE_COEF_SCALE); + + int isx2 = isx1 + 1; + + uchar3 d0; + + int3 s11 = + make_int3(input[(isy1 * inputWidth + isx1) * inputChannels + 0], + input[(isy1 * inputWidth + isx1) * inputChannels + 1], + input[(isy1 * inputWidth + isx1) * inputChannels + 2]); + int3 s12 = + make_int3(input[(isy1 * inputWidth + isx2) * inputChannels + 0], + input[(isy1 * inputWidth + isx2) * inputChannels + 1], + input[(isy1 * inputWidth + isx2) * inputChannels + 2]); + int3 s21 = + make_int3(input[(isy2 * inputWidth + isx1) * inputChannels + 0], + input[(isy2 * inputWidth + isx1) * inputChannels + 1], + input[(isy2 * inputWidth + isx1) * inputChannels + 2]); + int3 s22 = + make_int3(input[(isy2 * inputWidth + isx2) * inputChannels + 0], + input[(isy2 * inputWidth + isx2) * inputChannels + 1], + input[(isy2 * inputWidth + isx2) * inputChannels + 2]); + + int h_rst00, h_rst01; + // B + if (dx > xmax - 1) { + h_rst00 = s11.x * INTER_RESIZE_COEF_SCALE; + h_rst01 = s21.x * INTER_RESIZE_COEF_SCALE; + } else { + h_rst00 = s11.x * cbufx.x + s12.x * cbufx.y; + h_rst01 = s21.x * cbufx.x + s22.x * cbufx.y; + } + d0.x = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) + + ((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >> + 2); + + // G + if (dx > xmax - 1) { + h_rst00 = s11.y * INTER_RESIZE_COEF_SCALE; + h_rst01 = s21.y * INTER_RESIZE_COEF_SCALE; + } else { + h_rst00 = s11.y * cbufx.x + s12.y * cbufx.y; + h_rst01 = s21.y * cbufx.x + s22.y * cbufx.y; + } + d0.y = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) + + ((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >> + 2); + // R + if (dx > xmax - 1) { + h_rst00 = s11.z * INTER_RESIZE_COEF_SCALE; + h_rst01 = s21.z * INTER_RESIZE_COEF_SCALE; + } else { + h_rst00 = s11.z * cbufx.x + s12.z * cbufx.y; + h_rst01 = s21.z * cbufx.x + s22.z * cbufx.y; + } + d0.z = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) + + ((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >> + 2); + + output[(dy * outputWidth + dx) * 3 + 0] = (d0.x); // R + output[(dy * outputWidth + dx) * 3 + 1] = (d0.y); // G + output[(dy * outputWidth + dx) * 3 + 2] = (d0.z); // B + } else { + // TODO(Zelda): support alpha channel + } + } +} + +extern "C" cudaError_t resize_linear(const float* input, + float* output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels, + const bool use_fixed_point) { + // Specify a reasonable block size + const dim3 block(16, 16); + + // Calculate grid size to cover the whole image + const dim3 grid((outputWidth + block.x - 1) / block.x, + (outputHeight + block.y - 1) / block.y); + + // Launch the size conversion kernel + if (use_fixed_point) { + resizeCudaKernel_fixpt<<>>(input, + output, + inputWidth, + inputHeight, + outputWidth, + outputHeight, + inputChannels); + } else { + resizeCudaKernel<<>>(input, + output, + inputWidth, + inputHeight, + outputWidth, + outputHeight, + inputChannels); + } + + // Synchronize to check for any kernel launch errors + return cudaDeviceSynchronize(); +} diff --git a/core/preprocess/hwvideoframe/include/center_crop.h b/core/preprocess/hwvideoframe/include/center_crop.h new file mode 100644 index 0000000000000000000000000000000000000000..11830fdaa9e0a9aa7b5b002d6bd2decd8fd10fd2 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/center_crop.h @@ -0,0 +1,33 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_ + +#include +#include +#include "./op_context.h" + +// Crops the given Image at the center. +// the size must not bigger than any inputs' height and width +class CenterCrop { + public: + explicit CenterCrop(int size) : _size(size) {} + std::shared_ptr operator()(std::shared_ptr input); + + private: + int _size; +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_ diff --git a/core/preprocess/hwvideoframe/include/div.h b/core/preprocess/hwvideoframe/include/div.h new file mode 100644 index 0000000000000000000000000000000000000000..67482d00ca9300a6c6312b8c0d80b5b684e76da0 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/div.h @@ -0,0 +1,32 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_ + +#include +#include +#include "core/preprocess/hwvideoframe/include/op_context.h" + +// divide by some float number for all pixel +class Div { + public: + explicit Div(float value); + std::shared_ptr operator()(std::shared_ptr input); + + private: + Npp32f _divisor; +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_ diff --git a/core/preprocess/hwvideoframe/include/image_io.h b/core/preprocess/hwvideoframe/include/image_io.h new file mode 100644 index 0000000000000000000000000000000000000000..bbc15ef39125275b1830b939c800146b0779fb80 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/image_io.h @@ -0,0 +1,36 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_ + +#include +#include +#include + +#include "core/preprocess/hwvideoframe/include/op_context.h" + +// Input operator that copy numpy data to gpu buffer +class Image2Gpubuffer { + public: + std::shared_ptr operator()(pybind11::array_t array); +}; + +// Output operator that copy gpu buffer data to numpy +class Gpubuffer2Image { + public: + pybind11::array_t operator()(std::shared_ptr input); +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_ diff --git a/core/preprocess/hwvideoframe/include/normalize.h b/core/preprocess/hwvideoframe/include/normalize.h new file mode 100644 index 0000000000000000000000000000000000000000..db28efa7ef34329b1837e8b9e39e5d80fc97110c --- /dev/null +++ b/core/preprocess/hwvideoframe/include/normalize.h @@ -0,0 +1,39 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_ + +#include +#include +#include + +#include "core/preprocess/hwvideoframe/include/op_context.h" + +// utilize normalize operator on gpu +class Normalize { + public: + Normalize(const std::vector &mean, + const std::vector &std, + bool channel_first = false); + std::shared_ptr operator()(std::shared_ptr input); + + private: + Npp32f _mean[CHANNEL_SIZE]; + Npp32f _std[CHANNEL_SIZE]; + bool _channel_first; // indicate whether the channel is dimension 0, + // unsupported +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_ diff --git a/core/preprocess/hwvideoframe/include/op_context.h b/core/preprocess/hwvideoframe/include/op_context.h new file mode 100644 index 0000000000000000000000000000000000000000..45fd3351fafe6682bb78c018b5e6dff18118cfb3 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/op_context.h @@ -0,0 +1,67 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_ + +#include + +const size_t CHANNEL_SIZE = 3; + +// The context as input/ouput of all operators +// contains pointer to raw data on gpu, frame size +class OpContext { + public: + OpContext() { + _step = 0; + _size = 0; + _p_frame = nullptr; + } + // constructor to apply gpu memory of image raw data + OpContext(int height, int width) { + _step = sizeof(Npp32f) * width * CHANNEL_SIZE; + _length = height * width * CHANNEL_SIZE; + _size = _step * height; + _nppi_size.height = height; + _nppi_size.width = width; + cudaMalloc(reinterpret_cast(&_p_frame), _size); + } + virtual ~OpContext() { free_memory(); } + + public: + Npp32f* p_frame() const { return _p_frame; } + int step() const { return _step; } + int length() const { return _length; } + int size() const { return _size; } + NppiSize& nppi_size() { return _nppi_size; } + void free_memory() { + if (_p_frame != nullptr) { + cudaFree(_p_frame); + _p_frame = nullptr; + } + _nppi_size.height = 0; + _nppi_size.width = 0; + _step = 0; + _size = 0; + } + + private: + Npp32f* _p_frame; // pointer to raw data on gpu + int _step; // number of bytes in a row + int _length; // length of _p_frame, _size = _length * sizeof(Npp32f) + int _size; // number of bytes of the image + NppiSize _nppi_size; // contains height and width +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_ diff --git a/core/preprocess/hwvideoframe/include/resize.h b/core/preprocess/hwvideoframe/include/resize.h new file mode 100644 index 0000000000000000000000000000000000000000..808f812983fa4d539282d495fb8bc02c6f8ec817 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/resize.h @@ -0,0 +1,69 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_ + +#include +#include +#include + +#include "core/preprocess/hwvideoframe/include/op_context.h" + +extern "C" cudaError_t resize_linear(const float* input, + float* output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels, + const bool use_fixed_point); + +// Resize the input numpy array Image to the given size. +// only support linear interpolation +// only support RGB channels +class Resize { + public: + // size is an int, smaller edge of the image will be matched to this number. + Resize(int size, + int max_size = 214748364, + bool use_fixed_point = false, + int interpolation = 0) + : _size(size), + _max_size(max_size), + _use_fixed_point(use_fixed_point), + _interpolation(interpolation) {} + // size is a sequence like (w, h), output size will be matched to this + Resize(std::vector size, + int max_size = 214748364, + bool use_fixed_point = false, + int interpolation = 0) + : _size(-1), + _max_size(max_size), + _use_fixed_point(use_fixed_point), + _interpolation(interpolation) { + _target_size[0] = size[0]; + _target_size[1] = size[1]; + } + std::shared_ptr operator()(std::shared_ptr input); + + private: + int _size; // target of smaller edge + int _target_size[2]; // target size sequence (w, h) + int _max_size; + bool _use_fixed_point; + int _interpolation; // unused +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_ diff --git a/core/preprocess/hwvideoframe/include/resize_by_factor.h b/core/preprocess/hwvideoframe/include/resize_by_factor.h new file mode 100644 index 0000000000000000000000000000000000000000..1247edddf929d0fff7fa07125b18b0960aa6aecf --- /dev/null +++ b/core/preprocess/hwvideoframe/include/resize_by_factor.h @@ -0,0 +1,59 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_ + +#include + +#include +#include + +#include "core/preprocess/hwvideoframe/include/op_context.h" + +extern "C" cudaError_t resize_linear(const float *input, + float *output, + const int inputWidth, + const int inputHeight, + const int outputWidth, + const int outputHeight, + const int inputChannels, + const bool use_fixed_point); + +// Resize the input numpy array Image to a size multiple of factor which is +// usually required by a network +// only support linear interpolation +// only support RGB channels +class ResizeByFactor { + public: + // Resize factor. make width and height multiple factor of the value of + // factor. Default is 32 + ResizeByFactor(int factor = 32, + int max_side_len = 2400, + bool use_fixed_point = false, + int interpolation = 0) + : _factor(factor), + _max_side_len(max_side_len), + _use_fixed_point(use_fixed_point), + _interpolation(interpolation) {} + std::shared_ptr operator()(std::shared_ptr input); + + private: + int _factor; // target of smaller edge + int _max_side_len; + bool _use_fixed_point; + int _interpolation; // unused +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_ diff --git a/core/preprocess/hwvideoframe/include/rgb_swap.h b/core/preprocess/hwvideoframe/include/rgb_swap.h new file mode 100644 index 0000000000000000000000000000000000000000..2da58a6ae28ff38bd06593fdb6ba1df262796bc0 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/rgb_swap.h @@ -0,0 +1,37 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_ + +#include +#include +#include "core/preprocess/hwvideoframe/include/op_context.h" + +// swap channel 0 and channel 2 for every pixel +// both RGB2BGR and BGR2RGB use this operator +class SwapChannel { + public: + SwapChannel() {} + std::shared_ptr operator()(std::shared_ptr input); + + private: + static const int + _ORDER[CHANNEL_SIZE]; // describing how channel values are permutated +}; + +class RGB2BGR : public SwapChannel {}; +class BGR2RGB : public SwapChannel {}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_ diff --git a/core/preprocess/hwvideoframe/include/sub.h b/core/preprocess/hwvideoframe/include/sub.h new file mode 100644 index 0000000000000000000000000000000000000000..4b6be2de5a24c9166155680985a4c70ad2d733ff --- /dev/null +++ b/core/preprocess/hwvideoframe/include/sub.h @@ -0,0 +1,34 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_ + +#include +#include +#include +#include "core/preprocess/hwvideoframe/include/op_context.h" + +// subtract by some float numbers +class Sub { + public: + explicit Sub(float subtractor); + explicit Sub(const std::vector &subtractors); + std::shared_ptr operator()(std::shared_ptr input); + + private: + Npp32f _subtractors[CHANNEL_SIZE]; +}; + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_ diff --git a/core/preprocess/hwvideoframe/include/utils.h b/core/preprocess/hwvideoframe/include/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..811317933bc6da3dfce0f554d9051541aa35def2 --- /dev/null +++ b/core/preprocess/hwvideoframe/include/utils.h @@ -0,0 +1,30 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_ +#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_ + +#include + +#include + +// verify return value of npp function +// throw an exception if failed +void verify_npp_ret(const std::string& function_name, NppStatus ret); + +// verify return value of cuda runtime function +// throw an exception if failed +void verify_cuda_ret(const std::string& function_name, cudaError_t ret); + +#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_ diff --git a/core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp b/core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab3723e0afd2f67e131de5accdcceb80f0712815 --- /dev/null +++ b/core/preprocess/hwvideoframe/pybind/pybind_gpu_preprocess.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "core/preprocess/hwvideoframe/include/center_crop.h" +#include "core/preprocess/hwvideoframe/include/div.h" +#include "core/preprocess/hwvideoframe/include/image_io.h" +#include "core/preprocess/hwvideoframe/include/normalize.h" +#include "core/preprocess/hwvideoframe/include/resize.h" +#include "core/preprocess/hwvideoframe/include/resize_by_factor.h" +#include "core/preprocess/hwvideoframe/include/rgb_swap.h" +#include "core/preprocess/hwvideoframe/include/sub.h" + +PYBIND11_MODULE(libgpupreprocess, m) { + pybind11::class_>(m, "OpContext"); + pybind11::class_(m, "Image2Gpubuffer") + .def(pybind11::init<>()) + .def("__call__", &Image2Gpubuffer::operator()); + pybind11::class_(m, "Gpubuffer2Image") + .def(pybind11::init<>()) + .def("__call__", &Gpubuffer2Image::operator()); + pybind11::class_(m, "RGB2BGR") + .def(pybind11::init<>()) + .def("__call__", &RGB2BGR::operator()); + pybind11::class_(m, "BGR2RGB") + .def(pybind11::init<>()) + .def("__call__", &BGR2RGB::operator()); + pybind11::class_
(m, "Div") + .def(pybind11::init()) + .def("__call__", &Div::operator()); + pybind11::class_(m, "Sub") + .def(pybind11::init()) + .def(pybind11::init&>()) + .def("__call__", &Sub::operator()); + pybind11::class_(m, "Normalize") + .def(pybind11::init&, + const std::vector&, + bool>(), + pybind11::arg("mean"), + pybind11::arg("std"), + pybind11::arg("channel_first") = false) + .def("__call__", &Normalize::operator()); + pybind11::class_(m, "CenterCrop") + .def(pybind11::init()) + .def("__call__", &CenterCrop::operator()); + pybind11::class_(m, "Resize") + .def(pybind11::init(), + pybind11::arg("size"), + pybind11::arg("max_size") = 214748364, + pybind11::arg("use_fixed_point") = false) + .def(pybind11::init&, int, bool>(), + pybind11::arg("target_size"), + pybind11::arg("max_size") = 214748364, + pybind11::arg("use_fixed_point") = false) + .def("__call__", &Resize::operator()); + pybind11::class_(m, "ResizeByFactor") + .def(pybind11::init(), + pybind11::arg("factor") = 32, + pybind11::arg("max_side_len") = 2400, + pybind11::arg("use_fixed_point") = false) + .def("__call__", &ResizeByFactor::operator()); +} diff --git a/core/preprocess/hwvideoframe/src/center_crop.cpp b/core/preprocess/hwvideoframe/src/center_crop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..872433c5ac6bccc4b46d9f21c5f2f2b5007b57f6 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/center_crop.cpp @@ -0,0 +1,39 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "core/preprocess/hwvideoframe/include/center_crop.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +std::shared_ptr CenterCrop::operator()( + std::shared_ptr input) { + int new_width = std::min(_size, input->nppi_size().width); + int new_height = std::min(_size, input->nppi_size().height); + auto output = std::make_shared(new_height, new_width); + int x_start = (input->nppi_size().width - new_width) / 2; + int y_start = (input->nppi_size().height - new_height) / 2; + Npp32f* p_src = input->p_frame() + + y_start * input->nppi_size().width * CHANNEL_SIZE + + x_start * CHANNEL_SIZE; + NppStatus ret = nppiCopy_32f_C3R(p_src, + input->step(), + output->p_frame(), + output->step(), + output->nppi_size()); + verify_npp_ret("nppiCopy_32f_C3R", ret); + return output; +} diff --git a/core/preprocess/hwvideoframe/src/div.cpp b/core/preprocess/hwvideoframe/src/div.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0711cb2f4babc16416f8c9f49426efda0fdcacae --- /dev/null +++ b/core/preprocess/hwvideoframe/src/div.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "core/preprocess/hwvideoframe/include/div.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +Div::Div(float value) { _divisor = value; } + +std::shared_ptr Div::operator()(std::shared_ptr input) { + NppStatus ret = nppsDivC_32f_I(_divisor, input->p_frame(), input->length()); + verify_npp_ret("nppsDivC_32f_I", ret); + return input; +} diff --git a/core/preprocess/hwvideoframe/src/image_io.cpp b/core/preprocess/hwvideoframe/src/image_io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b02d5aa8761ed58e335fbff48c7f6d4f8dcd8e5 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/image_io.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "core/preprocess/hwvideoframe/include/image_io.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +std::shared_ptr Image2Gpubuffer::operator()( + pybind11::array_t input) { + pybind11::buffer_info buf = input.request(); + if (buf.format != pybind11::format_descriptor::format()) { + throw std::runtime_error("Incompatible format: expected a float numpy!"); + } + if (buf.ndim != 3) { + throw std::runtime_error("Number of dimensions must be three"); + } + if (buf.shape[2] != CHANNEL_SIZE) { + throw std::runtime_error("Number of channels must be three"); + } + auto result = std::make_shared(buf.shape[0], buf.shape[1]); + auto ret = cudaMemcpy(result->p_frame(), + static_cast(buf.ptr), + result->size(), + cudaMemcpyHostToDevice); + verify_cuda_ret("cudaMemcpy", ret); + return result; +} + +pybind11::array_t Gpubuffer2Image::operator()( + std::shared_ptr input) { + auto result = pybind11::array_t({input->nppi_size().height, + input->nppi_size().width, + static_cast(CHANNEL_SIZE)}); + pybind11::buffer_info buf = result.request(); + auto ret = cudaMemcpy(static_cast(buf.ptr), + input->p_frame(), + input->size(), + cudaMemcpyDeviceToHost); + verify_cuda_ret("cudaMemcpy", ret); + return result; +} diff --git a/core/preprocess/hwvideoframe/src/normalize.cpp b/core/preprocess/hwvideoframe/src/normalize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23042fb6bb731f0874133da970aac01f341bffc7 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/normalize.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "core/preprocess/hwvideoframe/include/normalize.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +Normalize::Normalize(const std::vector &mean, + const std::vector &std, + bool channel_first) { + if (mean.size() != CHANNEL_SIZE) { + throw std::runtime_error("size of mean must be three"); + } + if (std.size() != CHANNEL_SIZE) { + throw std::runtime_error("size of std must be three"); + } + for (size_t i = 0; i < CHANNEL_SIZE; i++) { + _mean[i] = mean[i]; + _std[i] = std[i]; + } + _channel_first = channel_first; +} + +std::shared_ptr Normalize::operator()( + std::shared_ptr input) { + NppStatus ret = nppiSubC_32f_C3IR( + _mean, input->p_frame(), input->step(), input->nppi_size()); + verify_npp_ret("nppiSubC_32f_C3IR", ret); + ret = nppiDivC_32f_C3IR( + _std, input->p_frame(), input->step(), input->nppi_size()); + verify_npp_ret("nppiDivC_32f_C3IR", ret); + return input; +} diff --git a/core/preprocess/hwvideoframe/src/resize.cpp b/core/preprocess/hwvideoframe/src/resize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d75cf66c4cf50d244f2bda237ff942fcda122dd0 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/resize.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "core/preprocess/hwvideoframe/include/resize.h" + +#include +#include + +#include "core/preprocess/hwvideoframe/include/utils.h" + +std::shared_ptr Resize::operator()( + std::shared_ptr input) { + int resized_width = 0, resized_height = 0; + if (_size == -1) { + resized_width = std::min(_target_size[0], _max_size); + resized_height = std::min(_target_size[1], _max_size); + } else { + int im_max_size = + std::max(input->nppi_size().height, input->nppi_size().width); + float percent = + static_cast(_size) / + std::min(input->nppi_size().height, input->nppi_size().width); + if (round(percent * im_max_size) > _max_size) { + percent = static_cast(_max_size) / static_cast(im_max_size); + } + resized_width = static_cast(round(input->nppi_size().width * percent)); + resized_height = + static_cast(round(input->nppi_size().height * percent)); + } + auto output = std::make_shared(resized_height, resized_width); + auto ret = resize_linear(input->p_frame(), + output->p_frame(), + input->nppi_size().width, + input->nppi_size().height, + output->nppi_size().width, + output->nppi_size().height, + CHANNEL_SIZE, + _use_fixed_point); + verify_cuda_ret("resize_linear", ret); + return output; +} diff --git a/core/preprocess/hwvideoframe/src/resize_by_factor.cpp b/core/preprocess/hwvideoframe/src/resize_by_factor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd59e0ac2ccb0a1ce6ddf53a188082be6e44db08 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/resize_by_factor.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "core/preprocess/hwvideoframe/include/resize.h" +#include "core/preprocess/hwvideoframe/include/resize_by_factor.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +std::shared_ptr ResizeByFactor::operator()( + std::shared_ptr input) { + int resized_width = input->nppi_size().width, + resized_height = input->nppi_size().height; + float radio = 0; + if (std::max(resized_width, resized_height) > _max_side_len) { + if (resized_width > resized_height) { + radio = static_cast(_max_side_len / resized_width); + } else { + radio = static_cast(_max_side_len / resized_height); + } + } else { + radio = 1; + } + resized_width = static_cast(resized_width * radio); + resized_height = static_cast(resized_height * radio); + if (resized_height % _factor == 0) { + resized_height = resized_height; + } else if (floor(resized_height / _factor) <= 1) { + resized_height = _factor; + } else { + resized_height = (floor(resized_height / 32) - 1) * 32; + } + if (resized_width % _factor == 0) { + resized_width = resized_width; + } else if (floor(resized_width / _factor) <= 1) { + resized_width = _factor; + } else { + resized_width = (floor(resized_width / 32) - 1) * _factor; + } + if (static_cast(resized_width) <= 0 || + static_cast(resized_height) <= 0) { + return NULL; + } + auto output = std::make_shared(resized_height, resized_width); + auto ret = resize_linear(input->p_frame(), + output->p_frame(), + input->nppi_size().width, + input->nppi_size().height, + output->nppi_size().width, + output->nppi_size().height, + CHANNEL_SIZE, + _use_fixed_point); + verify_cuda_ret("resize_linear", ret); + return output; +} diff --git a/core/preprocess/hwvideoframe/src/rgb_swap.cpp b/core/preprocess/hwvideoframe/src/rgb_swap.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a123a2323c5621c2875eec29254629f9fb12e35b --- /dev/null +++ b/core/preprocess/hwvideoframe/src/rgb_swap.cpp @@ -0,0 +1,28 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "core/preprocess/hwvideoframe/include/rgb_swap.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +const int SwapChannel::_ORDER[CHANNEL_SIZE] = {2, 1, 0}; + +std::shared_ptr SwapChannel::operator()( + std::shared_ptr input) { + NppStatus ret = nppiSwapChannels_32f_C3IR( + input->p_frame(), input->step(), input->nppi_size(), _ORDER); + verify_npp_ret("nppiSwapChannels_32f_C3IR", ret); + return input; +} diff --git a/core/preprocess/hwvideoframe/src/sub.cpp b/core/preprocess/hwvideoframe/src/sub.cpp new file mode 100644 index 0000000000000000000000000000000000000000..529d13d46727ace1219b24dea6e8ebfe026ca846 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/sub.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "core/preprocess/hwvideoframe/include/sub.h" +#include "core/preprocess/hwvideoframe/include/utils.h" + +Sub::Sub(float subtractor) { + for (size_t i = 0; i < CHANNEL_SIZE; i++) { + _subtractors[i] = subtractor; + } +} + +Sub::Sub(const std::vector &subtractors) { + if (subtractors.size() != CHANNEL_SIZE) { + throw std::runtime_error("size of subtractors must be three"); + } + for (size_t i = 0; i < CHANNEL_SIZE; i++) { + _subtractors[i] = subtractors[i]; + } +} + +std::shared_ptr Sub::operator()(std::shared_ptr input) { + NppStatus ret = nppiSubC_32f_C3IR( + _subtractors, input->p_frame(), input->step(), input->nppi_size()); + verify_npp_ret("nppiSubC_32f_C3IR", ret); + return input; +} diff --git a/core/preprocess/hwvideoframe/src/utlis.cpp b/core/preprocess/hwvideoframe/src/utlis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ec5d18e769f85c7f2a7be5c6d91daa8b8e3f1e0 --- /dev/null +++ b/core/preprocess/hwvideoframe/src/utlis.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "core/preprocess/hwvideoframe/include/utils.h" + +void verify_npp_ret(const std::string& function_name, NppStatus ret) { + if (ret != NPP_SUCCESS) { + std::ostringstream ss; + ss << function_name << ", ret: " << ret; + throw std::runtime_error(ss.str()); + } +} + +void verify_cuda_ret(const std::string& function_name, cudaError_t ret) { + if (ret != cudaSuccess) { + std::ostringstream ss; + ss << function_name << ", ret: " << ret; + throw std::runtime_error(ss.str()); + } +} diff --git a/doc/COMPILE.md b/doc/COMPILE.md index ec1900a89b9b2f0112f0d44adc539b138438bba7..60c7203de19663eaa43a2d5d29f48d90fe27f969 100644 --- a/doc/COMPILE.md +++ b/doc/COMPILE.md @@ -77,7 +77,7 @@ export PYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.8 ## Install Python dependencies ```shell -pip install -r python/requirements.txt +pip install -r python/requirements.txt -i https://mirror.baidu.com/pypi/simple ``` If you use other Python version, please use the right `pip` accordingly. @@ -123,14 +123,13 @@ Compared with CPU environment, GPU environment needs to refer to the following t **It should be noted that the following table is used as a reference for non-Docker compilation environment. The Docker compilation environment has been configured with relevant parameters and does not need to be specified in cmake process. ** | cmake environment variable | meaning | GPU environment considerations | whether Docker environment is needed | -|-----------------------|------------------------- ------------|-------------------------------|----- ---------------| -| CUDA_TOOLKIT_ROOT_DIR | cuda installation path, usually /usr/local/cuda | Required for all environments | No -(/usr/local/cuda) | +|-----------------------|-------------------------------------|-------------------------------|--------------------| +| CUDA_TOOLKIT_ROOT_DIR | cuda installation path, usually /usr/local/cuda | Required for all environments | No (/usr/local/cuda) | | CUDNN_LIBRARY | The directory where libcudnn.so.* is located, usually /usr/local/cuda/lib64/ | Required for all environments | No (/usr/local/cuda/lib64/) | | CUDA_CUDART_LIBRARY | The directory where libcudart.so.* is located, usually /usr/local/cuda/lib64/ | Required for all environments | No (/usr/local/cuda/lib64/) | | TENSORRT_ROOT | The upper level directory of the directory where libnvinfer.so.* is located, depends on the TensorRT installation directory | Cuda 9.0/10.0 does not need, other needs | No (/usr) | -If not in Docker environment, users can refer to the following execution methods. The specific path is subject to the current environment, and the code is only for reference. +If not in Docker environment, users can refer to the following execution methods. The specific path is subject to the current environment, and the code is only for reference.TENSORRT_LIBRARY_PATH is related to the TensorRT version and should be set according to the actual situation。For example, in the cuda10.1 environment, the TensorRT version is 6.0 (/usr/local/TensorRT-6.0.1.5/targets/x86_64-linux-gnu/),In the cuda10.2 environment, the TensorRT version is 7.1 (/usr/local/TensorRT-7.1.3.4/targets/x86_64-linux-gnu/). ``` shell export CUDA_PATH='/usr/local/cuda' @@ -145,7 +144,7 @@ cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \ -DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \ -DCUDNN_LIBRARY=${CUDNN_LIBRARY} \ -DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \ - -DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} + -DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \ -DSERVER=ON \ -DWITH_GPU=ON .. make -j10 diff --git a/doc/COMPILE_CN.md b/doc/COMPILE_CN.md index 740a33028c2c1fff7364e3d771360d4a579e3ae8..ec8a482e7cecb9a5d412e7759d688aa875b9aeef 100644 --- a/doc/COMPILE_CN.md +++ b/doc/COMPILE_CN.md @@ -76,7 +76,7 @@ export PYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.8 ## 安装Python依赖 ```shell -pip install -r python/requirements.txt +pip install -r python/requirements.txt -i https://mirror.baidu.com/pypi/simple ``` 如果使用其他Python版本,请使用对应版本的`pip`。 @@ -128,7 +128,7 @@ make -j10 | CUDA_CUDART_LIBRARY | libcudart.so.*所在目录,通常为/usr/local/cuda/lib64/ | 全部环境都需要 | 否(/usr/local/cuda/lib64/) | | TENSORRT_ROOT | libnvinfer.so.*所在目录的上一级目录,取决于TensorRT安装目录 | Cuda 9.0/10.0不需要,其他需要 | 否(/usr) | -非Docker环境下,用户可以参考如下执行方式,具体的路径以当时环境为准,代码仅作为参考。 +非Docker环境下,用户可以参考如下执行方式,具体的路径以当时环境为准,代码仅作为参考。TENSORRT_LIBRARY_PATH和TensorRT版本有关,要根据实际情况设置。例如在cuda10.1环境下TensorRT版本是6.0(/usr/local/TensorRT-6.0.1.5/targets/x86_64-linux-gnu/),在cuda10.2环境下TensorRT版本是7.1(/usr/local/TensorRT-7.1.3.4/targets/x86_64-linux-gnu/)。 ``` shell export CUDA_PATH='/usr/local/cuda' @@ -143,7 +143,7 @@ cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \ -DCUDA_TOOLKIT_ROOT_DIR=${CUDA_PATH} \ -DCUDNN_LIBRARY=${CUDNN_LIBRARY} \ -DCUDA_CUDART_LIBRARY=${CUDA_CUDART_LIBRARY} \ - -DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} + -DTENSORRT_ROOT=${TENSORRT_LIBRARY_PATH} \ -DSERVER=ON \ -DWITH_GPU=ON .. make -j10 @@ -159,7 +159,7 @@ make -j10 mkdir client-build && cd client-build cmake -DPYTHON_INCLUDE_DIR=$PYTHON_INCLUDE_DIR \ -DPYTHON_LIBRARIES=$PYTHON_LIBRARIES \ - -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ -DCLIENT=ON .. make -j10 ``` diff --git a/python/examples/bert/README.md b/python/examples/bert/README.md index 1313ad4151ac9cfae2a9ad516a3ca34bebbad3bd..1fde6d46625af8513ee244ab9c0865cccfe05a20 100644 --- a/python/examples/bert/README.md +++ b/python/examples/bert/README.md @@ -33,6 +33,7 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client ``` +if your model is bert_chinese_L-12_H-768_A-12_model, replace the 'bert_seq128_model' field in the following command with 'bert_chinese_L-12_H-768_A-12_model',replace 'bert_seq128_client' with 'bert_chinese_L-12_H-768_A-12_client'. ### Getting Dict and Sample Dataset diff --git a/python/examples/bert/README_CN.md b/python/examples/bert/README_CN.md index 4fa42c78a913742873556290fbd79ef6b97a7d01..060c5579af6d2772ed666fda6f023245bf881213 100644 --- a/python/examples/bert/README_CN.md +++ b/python/examples/bert/README_CN.md @@ -28,6 +28,9 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client ``` +若使用bert_chinese_L-12_H-768_A-12_model模型,将下面命令中的bert_seq128_model字段替换为bert_chinese_L-12_H-768_A-12_model,bert_seq128_client字段替换为bert_chinese_L-12_H-768_A-12_client. + + ### 获取词典和样例数据 diff --git a/python/examples/detection/README.md b/python/examples/detection/README.md index 83f6157c0d29fc3b1c672c06473487e9e4efe3f0..c42b3863c312e68b444385d59d905b850a23f697 100644 --- a/python/examples/detection/README.md +++ b/python/examples/detection/README.md @@ -18,3 +18,6 @@ All examples support TensorRT. - [PPYOLO](./ppyolo_r50vd_dcn_1x_coco) - [TTFNet](./ttfnet_darknet53_1x_coco) - [YOLOv3](./yolov3_darknet53_270e_coco) +- [HRNet](./faster_rcnn_hrnetv2p_w18_1x) +- [Fcos](./fcos_dcn_r50_fpn_1x_coco) +- [SSD](./ssd_vgg16_300_240e_voc/) diff --git a/python/examples/detection/README_CN.md b/python/examples/detection/README_CN.md index a04fe40ab3e116375a04f84132dd0a0a35eb2ef7..3a50afd81a87a59440d561d6849a6bd493f4012c 100644 --- a/python/examples/detection/README_CN.md +++ b/python/examples/detection/README_CN.md @@ -19,4 +19,6 @@ Paddle Detection提供了大量的[模型库](https://github.com/PaddlePaddle/Pa - [PPYOLO](./ppyolo_r50vd_dcn_1x_coco) - [TTFNet](./ttfnet_darknet53_1x_coco) - [YOLOv3](./yolov3_darknet53_270e_coco) - +- [HRNet](./faster_rcnn_hrnetv2p_w18_1x) +- [Fcos](./fcos_dcn_r50_fpn_1x_coco) +- [SSD](./ssd_vgg16_300_240e_voc/) diff --git a/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/000000570688.jpg b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/000000570688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54 Binary files /dev/null and b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/000000570688.jpg differ diff --git a/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/README.md b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/README.md new file mode 100644 index 0000000000000000000000000000000000000000..21ee05809042f5e6ee2e496306975f3ae18ed158 --- /dev/null +++ b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/README.md @@ -0,0 +1,22 @@ +# Faster RCNN HRNet model on Paddle Serving + +([简体中文](./README_CN.md)|English) + +### Get The Faster RCNN HRNet Model +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_hrnetv2p_w18_1x.tar +``` + +### Start the service +``` +tar xf faster_rcnn_hrnetv2p_w18_1x.tar +python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 +``` + +This model support TensorRT, if you want a faster inference, please use `--use_trt`. + + +### Prediction +``` +python test_client.py 000000570688.jpg +``` diff --git a/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/README_CN.md b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..30c455500c0d2a9cc5a68976e261292fee200c75 --- /dev/null +++ b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/README_CN.md @@ -0,0 +1,21 @@ +# 使用Paddle Serving部署Faster RCNN HRNet模型 + +(简体中文|[English](./README.md)) + +## 获得Faster RCNN HRNet模型 +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_hrnetv2p_w18_1x.tar +``` + + +### 启动服务 +``` +tar xf faster_rcnn_hrnetv2p_w18_1x.tar +python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 +``` +该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。 + +### 执行预测 +``` +python test_client.py 000000570688.jpg +``` diff --git a/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/label_list.txt b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/label_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..941cb4e1392266f6a6c09b1fdc5f79503b2e5df6 --- /dev/null +++ b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/label_list.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +dining table +toilet +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/test_client.py b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/test_client.py new file mode 100644 index 0000000000000000000000000000000000000000..da21478a5de0feef816a4dedb9e3aab7cd011719 --- /dev/null +++ b/python/examples/detection/faster_rcnn_hrnetv2p_w18_1x/test_client.py @@ -0,0 +1,27 @@ +from paddle_serving_client import Client +from paddle_serving_app.reader import * +import sys +import numpy as np + +preprocess = Sequential([ + File2Image(), BGR2RGB(), Div(255.0), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False), + Resize(640, 640), Transpose((2, 0, 1)) +]) + +postprocess = RCNNPostprocess("label_list.txt", "output") +client = Client() + +client.load_client_config("serving_client/serving_client_conf.prototxt") +client.connect(['127.0.0.1:9494']) + +im = preprocess(sys.argv[1]) +fetch_map = client.predict( + feed={ + "image": im, + "im_info": np.array(list(im.shape[1:]) + [1.0]), + "im_shape": np.array(list(im.shape[1:]) + [1.0]) + }, + fetch=["multiclass_nms_0.tmp_0"], + batch=False) +print(fetch_map) diff --git a/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/000000570688.jpg b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/000000570688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54 Binary files /dev/null and b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/000000570688.jpg differ diff --git a/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/README.md b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/README.md new file mode 100644 index 0000000000000000000000000000000000000000..251732c76e1996493eb7d785c721cd478f3b060b --- /dev/null +++ b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/README.md @@ -0,0 +1,21 @@ +# FCOS model on Paddle Serving + +([简体中文](./README_CN.md)|English) + +### Get Model +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/fcos_dcn_r50_fpn_1x_coco.tar +``` + +### Start the service +``` +tar xf fcos_dcn_r50_fpn_1x_coco.tar +python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 +``` +This model support TensorRT, if you want a faster inference, please use `--use_trt`. + +### Perform prediction +``` +python test_client.py 000000570688.jpg +``` + diff --git a/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/README_CN.md b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..d4a65e072a069efcfe93053b92bdb764f5cbcc32 --- /dev/null +++ b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/README_CN.md @@ -0,0 +1,23 @@ +# 使用Paddle Serving部署FCOS模型 + +(简体中文|[English](./README.md)) + +## 获得模型 +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/fcos_dcn_r50_fpn_1x_coco.tar +``` + + +### 启动服务 +``` +tar xf fcos_dcn_r50_fpn_1x_coco.tar +python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 +``` + +该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。 + +### 执行预测 +``` +python test_client.py 000000570688.jpg +``` + diff --git a/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/label_list.txt b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/label_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..941cb4e1392266f6a6c09b1fdc5f79503b2e5df6 --- /dev/null +++ b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/label_list.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +dining table +toilet +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/test_client.py b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/test_client.py new file mode 100755 index 0000000000000000000000000000000000000000..bf5504105a61df7912e6c34037287610a1939479 --- /dev/null +++ b/python/examples/detection/fcos_dcn_r50_fpn_1x_coco/test_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import * +import sys +import numpy as np + +preprocess = Sequential([ + File2Image(), BGR2RGB(), Div(255.0), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False), + Resize(640, 640), Transpose((2, 0, 1)) +]) + +postprocess = RCNNPostprocess("label_list.txt", "output") +client = Client() + +client.load_client_config("serving_client/serving_client_conf.prototxt") +client.connect(['127.0.0.1:9494']) + +im = preprocess(sys.argv[1]) +fetch_map = client.predict( + feed={ + "image": im, + "scale_factor": np.array([1.0, 1.0]).reshape(-1), + }, + fetch=["save_infer_model/scale_0.tmp_1"], + batch=False) +print(fetch_map) diff --git a/python/examples/detection/ssd_vgg16_300_240e_voc/000000570688.jpg b/python/examples/detection/ssd_vgg16_300_240e_voc/000000570688.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54 Binary files /dev/null and b/python/examples/detection/ssd_vgg16_300_240e_voc/000000570688.jpg differ diff --git a/python/examples/detection/ssd_vgg16_300_240e_voc/README.md b/python/examples/detection/ssd_vgg16_300_240e_voc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f449dc45da135b353fdb60a591a306a6ef3d40c3 --- /dev/null +++ b/python/examples/detection/ssd_vgg16_300_240e_voc/README.md @@ -0,0 +1,21 @@ +# SSD model on Paddle Serving + +([简体中文](./README_CN.md)|English) + +### Get Model +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ssd_vgg16_300_240e_voc.tar +``` + +### Start the service +``` +tar xf ssd_vgg16_300_240e_voc.tar +python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 +``` +This model support TensorRT, if you want a faster inference, please use `--use_trt`. + +### Perform prediction +``` +python test_client.py 000000570688.jpg +``` + diff --git a/python/examples/detection/ssd_vgg16_300_240e_voc/README_CN.md b/python/examples/detection/ssd_vgg16_300_240e_voc/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..332937cacc2f3bdf948a670de91dd20276473abc --- /dev/null +++ b/python/examples/detection/ssd_vgg16_300_240e_voc/README_CN.md @@ -0,0 +1,23 @@ +# 使用Paddle Serving部署SSD模型 + +(简体中文|[English](./README.md)) + +## 获得模型 +``` +wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ssd_vgg16_300_240e_voc.tar +``` + + +### 启动服务 +``` +tar xf ssd_vgg16_300_240e_voc.tar +python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0 +``` + +该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。 + +### 执行预测 +``` +python test_client.py 000000570688.jpg +``` + diff --git a/python/examples/detection/ssd_vgg16_300_240e_voc/label_list.txt b/python/examples/detection/ssd_vgg16_300_240e_voc/label_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..941cb4e1392266f6a6c09b1fdc5f79503b2e5df6 --- /dev/null +++ b/python/examples/detection/ssd_vgg16_300_240e_voc/label_list.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +dining table +toilet +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/python/examples/detection/ssd_vgg16_300_240e_voc/test_client.py b/python/examples/detection/ssd_vgg16_300_240e_voc/test_client.py new file mode 100755 index 0000000000000000000000000000000000000000..59024d010a27c1569b5a07afd4508ad19894d89e --- /dev/null +++ b/python/examples/detection/ssd_vgg16_300_240e_voc/test_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle_serving_client import Client +from paddle_serving_app.reader import * +import sys +import numpy as np + +preprocess = Sequential([ + File2Image(), BGR2RGB(), + Normalize([123.675, 116.28, 103.53], [58.395, 57.12, 57.375], False), + Resize((512, 512)), Transpose((2, 0, 1)) +]) + +postprocess = RCNNPostprocess("label_list.txt", "output") +client = Client() + +client.load_client_config("serving_client/serving_client_conf.prototxt") +client.connect(['127.0.0.1:9494']) + +im = preprocess(sys.argv[1]) +fetch_map = client.predict( + feed={ + "image": im, + "scale_factor": np.array([1.0, 1.0]).reshape(-1), + }, + fetch=["save_infer_model/scale_0.tmp_1"], + batch=False) +print(fetch_map) diff --git a/python/paddle_serving_app/local_predict.py b/python/paddle_serving_app/local_predict.py index 1c49f01f22cbc23cfecb70fb36d3a72ff0991e5f..3aa74bb804a89904735ab6b085db0b49453e79fe 100644 --- a/python/paddle_serving_app/local_predict.py +++ b/python/paddle_serving_app/local_predict.py @@ -82,7 +82,10 @@ class LocalPredictor(object): f = open(client_config, 'r') model_conf = google.protobuf.text_format.Merge( str(f.read()), model_conf) - config = AnalysisConfig(model_path) + if os.path.exists(os.path.join(model_path, "__params__")): + config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__")) + else: + config = AnalysisConfig(model_path) logger.info("load_model_config params: model_path:{}, use_gpu:{},\ gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\ use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format( diff --git a/python/paddle_serving_app/reader/test_preprocess.py b/python/paddle_serving_app/reader/test_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..47c244fa7103753341b21462b3a498153526e5e7 --- /dev/null +++ b/python/paddle_serving_app/reader/test_preprocess.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=doc-string-missing + +import unittest +import sys +import numpy as np +import cv2 +from paddle_serving_app.reader import Sequential, Resize, File2Image +import libgpupreprocess as pp +import libhwextract + + +class TestOperators(unittest.TestCase): + """ + test all operators, e.g. Div, Normalize + """ + + def test_div(self): + height = 4 + width = 5 + channels = 3 + value = 255.0 + img = np.arange(height * width * channels).reshape( + [height, width, channels]) + seq = Sequential( + [pp.Image2Gpubuffer(), pp.Div(value), pp.Gpubuffer2Image()]) + result = seq(img).reshape(-1) + for i in range(0, result.size): + self.assertAlmostEqual(i / value, result[i], 5) + + def test_sub(self): + height = 4 + width = 5 + channels = 3 + img = np.arange(height * width * channels).reshape( + [height, width, channels]) + + # input size is an int + value = 10.0 + seq = Sequential( + [pp.Image2Gpubuffer(), pp.Sub(value), pp.Gpubuffer2Image()]) + result = seq(img).reshape(-1) + for i in range(0, result.size): + self.assertEqual(i - value, result[i]) + + # input size is a sequence + values = (9, 4, 2) + seq = Sequential( + [pp.Image2Gpubuffer(), pp.Sub(values), pp.Gpubuffer2Image()]) + result = seq(img) + for i in range(0, result.shape[0]): + for j in range(0, result.shape[1]): + for k in range(0, result.shape[2]): + self.assertEqual(result[i][j][k], img[i][j][k] - values[k]) + + def test_normalize(self): + height = 4 + width = 5 + channels = 3 + img = np.random.rand(height, width, channels) + mean = [5.0, 5.0, 5.0] + std = [2.0, 2.0, 2.0] + seq = Sequential([ + pp.Image2Gpubuffer(), pp.Normalize(mean, std), pp.Gpubuffer2Image() + ]) + result = seq(img) + for i in range(0, height): + for j in range(0, width): + for k in range(0, channels): + self.assertAlmostEqual((img[i][j][k] - mean[k]) / std[k], + result[i][j][k], 5) + + def test_center_crop(self): + height = 9 + width = 7 + channels = 3 + img = np.arange(height * width * channels).reshape( + [height, width, channels]) + new_size = 5 + seq = Sequential([ + pp.Image2Gpubuffer(), pp.CenterCrop(new_size), pp.Gpubuffer2Image() + ]) + result = seq(img) + self.assertEqual(result.shape[0], new_size) + self.assertEqual(result.shape[1], new_size) + self.assertEqual(result.shape[2], channels) + + def test_resize(self): + height = 9 + width = 5 + channels = 3 + img = np.arange(height * width).reshape([height, width, 1]) * np.ones( + (1, channels)) + + # input size is an int + for new_size in [3, 10]: + seq_gpu = Sequential([ + pp.Image2Gpubuffer(), pp.Resize(new_size), pp.Gpubuffer2Image() + ]) + seq_paddle = Sequential([Resize(new_size)]) + result_gpu = seq_gpu(img) + result_paddle = seq_paddle(img) + self.assertEqual(result_gpu.shape, result_paddle.shape) + for i in range(0, result_gpu.shape[0]): + for j in range(0, result_gpu.shape[1]): + for k in range(0, result_gpu.shape[2]): + self.assertAlmostEqual(result_gpu[i][j][k], + result_paddle[i][j][k], 5) + + # input size is a sequence + for new_height, new_width in [(7, 3), (15, 10)]: + seq_gpu = Sequential([ + pp.Image2Gpubuffer(), pp.Resize((new_width, new_height)), + pp.Gpubuffer2Image() + ]) + seq_paddle = Sequential([Resize((new_width, new_height))]) + result_gpu = seq_gpu(img) + result_paddle = seq_paddle(img) + self.assertEqual(result_gpu.shape, result_paddle.shape) + for i in range(0, result_gpu.shape[0]): + for j in range(0, result_gpu.shape[1]): + for k in range(0, result_gpu.shape[2]): + self.assertAlmostEqual(result_gpu[i][j][k], + result_paddle[i][j][k], 5) + + def test_resize_fixed_point(self): + new_height = 256 + new_width = 256 * 4 / 3 + seq = Sequential([ + File2Image(), pp.Image2Gpubuffer(), pp.Resize( + (new_width, new_height), use_fixed_point=True), + pp.Gpubuffer2Image() + ]) + img = seq("./capture_16.bmp") + img = np.resize(img, (new_height, new_width * 3)) + img_vis = np.loadtxt("./cap_resize_16.raw") + img_resize_diff = img_vis - img + self.assertEqual(np.all(img_resize_diff == 0), True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle_serving_app/utils/__init__.py b/python/paddle_serving_app/utils/__init__.py index 847ddc47ac89114f2012bc6b9990a69abfe39fb3..d8e70508ecb76d149d145c1774d69c4b7e79338e 100644 --- a/python/paddle_serving_app/utils/__init__.py +++ b/python/paddle_serving_app/utils/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=doc-string-missing + +from .arr2image import Arr2Image diff --git a/python/paddle_serving_app/utils/arr2image.py b/python/paddle_serving_app/utils/arr2image.py new file mode 100644 index 0000000000000000000000000000000000000000..295ed5cc784b3c3403f3fb0edb029314d09659e0 --- /dev/null +++ b/python/paddle_serving_app/utils/arr2image.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=doc-string-missing + +import cv2 +import yaml + + +class Arr2Image(object): + """ + from numpy array image(jpeg) to cv::Mat image + """ + + def __init__(self): + pass + + def __call__(self, img_arr): + img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) + return img + + def __repr__(self): + return self.__class__.__name__ + "()" diff --git a/python/setup.py.app.in b/python/setup.py.app.in index a9b58a11877a410815c2159ddbce6afbf311cc3b..c5ccc0eecf47db503b8efc847ab7db2cab2a122a 100644 --- a/python/setup.py.app.in +++ b/python/setup.py.app.in @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import platform import os from setuptools import setup, Distribution, Extension @@ -23,14 +24,22 @@ from setuptools import find_packages from setuptools import setup from paddle_serving_app.version import serving_app_version from pkg_resources import DistributionNotFound, get_distribution -import util -max_version, mid_version, min_version = util.python_version() +def python_version(): + return [int(v) for v in platform.python_version().split(".")] + +def find_package(pkgname): + try: + get_distribution(pkgname) + return True + except DistributionNotFound: + return False + +max_version, mid_version, min_version = python_version() if '${PACK}' == 'ON': copy_lib() - REQUIRED_PACKAGES = [ 'six >= 1.10.0', 'pillow', @@ -48,7 +57,11 @@ packages=['paddle_serving_app', 'paddle_serving_app.models', 'paddle_serving_app.reader.pddet'] -package_data={} +if os.path.exists('../core/preprocess/hwvideoframe/libgpupreprocess.so'): + package_data={'paddle_serving_app': ['../core/preprocess/hwvideoframe/libgpupreprocess.so'],} +else: + package_data={} + package_dir={'paddle_serving_app': '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app', 'paddle_serving_app.proto': @@ -60,7 +73,8 @@ package_dir={'paddle_serving_app': 'paddle_serving_app.models': '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/models', 'paddle_serving_app.reader.pddet': - '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',} + '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet', + } setup( name='paddle-serving-app', diff --git a/python/util.py b/python/util.py index 32dc2993077d1a73b880620549d924b54c1c3bf8..ef7cd6632855266096decc3139ca72a46ead6d2b 100644 --- a/python/util.py +++ b/python/util.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=doc-string-missing from pkg_resources import DistributionNotFound, get_distribution from grpc_tools import protoc diff --git a/tools/scripts/ipipe_py2.sh b/tools/scripts/ipipe_py2.sh new file mode 100644 index 0000000000000000000000000000000000000000..ce3aa91c3b4c1f31841342e5463effc7460013ed --- /dev/null +++ b/tools/scripts/ipipe_py2.sh @@ -0,0 +1,684 @@ +#!/bin/bash +echo "################################################################" +echo "# #" +echo "# #" +echo "# #" +echo "# Paddle Serving begin run with python2.7.15!! #" +echo "# #" +echo "# #" +echo "# #" +echo "################################################################" + +export GOPATH=$HOME/go +export PATH=$PATH:$GOROOT/bin:$GOPATH/bin +export CUDA_INCLUDE_DIRS=/usr/local/cuda-10.2/include +export PYTHONROOT=/usr/local/python2.7.15/ + +go env -w GO111MODULE=on +go env -w GOPROXY=https://goproxy.cn,direct +go get -u github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2 +go get -u github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2 +go get -u github.com/golang/protobuf/protoc-gen-go@v1.4.3 +go get -u google.golang.org/grpc@v1.33.0 + +build_path=/workspace/Serving +build_whl_list=(build_gpu_server build_client build_cpu_server build_app) +rpc_model_list=(grpc_impl pipeline_imagenet bert_rpc_gpu bert_rpc_cpu faster_rcnn_model_rpc ResNet50_rpc lac_rpc \ +cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc cascade_rcnn_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \ +ocr_rpc criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu yolov4_rpc_gpu) +http_model_list=(fit_a_line_http lac_http cnn_http bow_http lstm_http ResNet50_http bert_http) + +function setproxy(){ + export http_proxy=${proxy} + export https_proxy=${proxy} +} + +function unsetproxy(){ + unset http_proxy + unset https_proxy +} + +function kill_server_process(){ + kill `ps -ef|grep serving|awk '{print $2}'` +} + +function check() { + cd ${build_path} + if [ ! -f paddle_serving_app* ]; then + echo "paddle_serving_app is compiled failed, please check your pull request" + exit 1 + elif [ ! -f paddle_serving_server-* ]; then + echo "paddle_serving_server-cpu is compiled failed, please check your pull request" + exit 1 + elif [ ! -f paddle_serving_server_* ]; then + echo "paddle_serving_server_gpu is compiled failed, please check your pull request" + exit 1 + elif [ ! -f paddle_serving_client* ]; then + echo "paddle_serving_server_client is compiled failed, please check your pull request" + exit 1 + else + echo "paddle serving Build Passed" + fi +} + +function check_result() { + if [ $? -ne 0 ];then + echo -e "\033[4;31;42m$1 model runs failed, please check your pull request or modify test case! \033[0m" + exit 1 + else + echo -e "\033[4;37;42m$1 model runs successfully, congratulations! \033[0m" + fi +} + +function before_hook(){ + setproxy + cd ${build_path}/python + pip2.7 install --upgrade pip + pip2.7 install opencv-python==4.2.0.32 requests + pip2.7 install -r requirements.txt + echo "before hook configuration is successful.... " +} + +function run_env(){ + setproxy + pip2.7 install --upgrade nltk==3.4 + pip2.7 install --upgrade scipy==1.2.1 + pip2.7 install --upgrade setuptools + pip2.7 install paddlehub ujson paddlepaddle==2.0.0 + echo "run env configuration is successful.... " +} + +function run_gpu_env(){ + cd ${build_path} + export LD_LIBRARY_PATH=/usr/local/python2.7.15/lib/python2.7/site-packages/paddle/libs/:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/workspace/Serving/build_gpu/third_party/install/Paddle/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mklml/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mkldnn/lib/:$LD_LIBRARY_PATH + export SERVING_BIN=${build_path}/build_gpu/core/general-server/serving + echo "run gpu env configuration is successful.... " +} + +function run_cpu_env(){ + cd ${build_path} + export LD_LIBRARY_PATH=/usr/local/python2.7.15/lib/python2.7/site-packages/paddle/libs/:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/workspace/Serving/build_cpu/third_party/install/Paddle/lib/:$LD_LIBRARY_PATH + export SERVING_BIN=${build_path}/build_cpu/core/general-server/serving + echo "run cpu env configuration is successful.... " +} + +function build_gpu_server() { + setproxy + cd ${build_path} + git submodule update --init --recursive + if [ -d build_gpu ];then + rm -rf build_gpu + fi + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython2.7.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python \ + -DSERVER=ON \ + -DTENSORRT_ROOT=/usr \ + -DWITH_GPU=ON .. + make -j18 + make -j18 + make install -j18 + pip2.7 uninstall paddle-serving-server-gpu -y + pip2.7 install ${build_path}/build/python/dist/* + cp ${build_path}/build/python/dist/* ../ + cp -r ${build_path}/build/ ${build_path}/build_gpu +} + +function build_client() { + setproxy + cd ${build_path} + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython2.7.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python \ + -DCLIENT=ON .. + make -j18 + make -j18 + cp ${build_path}/build/python/dist/* ../ + pip2.7 uninstall paddle-serving-client -y + pip2.7 install ${build_path}/build/python/dist/* +} + +function build_cpu_server(){ + setproxy + cd ${build_path} + if [ -d build_cpu ];then + rm -rf build_cpu + fi + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + pwd + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython2.7.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python \ + -DWITH_GPU=OFF \ + -DSERVER=ON .. + make -j18 + make -j18 + make install -j18 + cp ${build_path}/build/python/dist/* ../ + pip2.7 uninstall paddle-serving-server -y + pip2.7 install ${build_path}/build/python/dist/* + cp -r ${build_path}/build/ ${build_path}/build_cpu +} + +function build_app() { + setproxy + cd ${build_path} + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython2.7.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python \ + -DCMAKE_INSTALL_PREFIX=./output -DAPP=ON .. + make + cp ${build_path}/build/python/dist/* ../ + pip2.7 uninstall paddle-serving-app -y + pip2.7 install ${build_path}/build/python/dist/* +} + +function bert_rpc_gpu(){ + run_gpu_env + setproxy + cd ${build_path}/python/examples/bert + sh get_data.sh >/dev/null 2>&1 + sed -i 's/9292/8860/g' bert_client.py + sed -i '$aprint(result)' bert_client.py + cp -r /root/.cache/dist_data/serving/bert/bert_seq128_* ./ + ls -hlst + python2.7 -m paddle_serving_server_gpu.serve --model bert_seq128_model/ --port 8860 --gpu_ids 0 > bert_rpc_gpu 2>&1 & + sleep 15 + head data-c.txt | python2.7 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt + cat bert_rpc_gpu + check_result $FUNCNAME + kill_server_process +} + +function bert_rpc_cpu(){ + run_cpu_env + setproxy + cd ${build_path}/python/examples/bert + sed -i 's/8860/8861/g' bert_client.py + python2.7 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 > bert_rpc_cpu 2>&1 & + sleep 3 + cp data-c.txt.1 data-c.txt + head data-c.txt | python2.7 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt + cat bert_rpc_cpu + check_result $FUNCNAME + kill_server_process +} + +function criteo_ctr_with_cube_rpc(){ + setproxy + run_cpu_env + cd ${build_path}/python/examples/criteo_ctr_with_cube + ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./ + sed -i "s/9292/8888/g" test_server.py + sed -i "s/9292/8888/g" test_client.py + wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz >/dev/null 2>&1 + tar xf ctr_cube_unittest.tar.gz + mv models/ctr_client_conf ./ + mv models/ctr_serving_model_kv ./ + mv models/data ./cube/ + wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz >/dev/null 2>&1 + tar xf cube_app.tar.gz + mv cube_app/cube* ./cube/ + sh cube_prepare.sh > haha 2>&1 & + sleep 5 + python2.7 test_server.py ctr_serving_model_kv > criteo_ctr_rpc 2>&1 & + sleep 5 + python2.7 test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data + cat criteo_ctr_rpc + check_result $FUNCNAME + kill `ps -ef|grep cube|awk '{print $2}'` + kill_server_process +} + +function pipeline_imagenet(){ + run_gpu_env + setproxy + cd ${build_path}/python/examples/pipeline/imagenet + cp -r /root/.cache/dist_data/serving/imagenet/* ./ + ls -a + python2.7 resnet50_web_service.py > pipelog 2>&1 & + sleep 5 + python2.7 pipeline_rpc_client.py + # check_result $FUNCNAME + kill_server_process +} + +function ResNet50_rpc(){ + run_gpu_env + setproxy + cd ${build_path}/python/examples/imagenet + cp -r /root/.cache/dist_data/serving/imagenet/* ./ + sed -i 's/9696/8863/g' resnet50_rpc_client.py + python2.7 -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 > ResNet50_rpc 2>&1 & + sleep 5 + python2.7 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt + tail ResNet50_rpc + check_result $FUNCNAME + kill_server_process + sleep 5 +} + +function ResNet101_rpc(){ + run_gpu_env + setproxy + cd ${build_path}/python/examples/imagenet + sed -i 's/9292/8864/g' image_rpc_client.py + python2.7 -m paddle_serving_server_gpu.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 > ResNet101_rpc 2>&1 & + sleep 5 + python2.7 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt + tail ResNet101_rpc + kill_server_process + check_result $FUNCNAME +} + +function cnn_rpc(){ + setproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + cp -r /root/.cache/dist_data/serving/imdb/* ./ + tar xf imdb_model.tar.gz && tar xf text_classification_data.tar.gz + sed -i 's/9292/8865/g' test_client.py + python2.7 -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 > cnn_rpc 2>&1 & + sleep 5 + head test_data/part-0 | python2.7 test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab + tail cnn_rpc + check_result $FUNCNAME + kill_server_process +} + +function bow_rpc(){ + setproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + sed -i 's/8865/8866/g' test_client.py + python2.7 -m paddle_serving_server.serve --model imdb_bow_model/ --port 8866 > bow_rpc 2>&1 & + sleep 5 + head test_data/part-0 | python2.7 test_client.py imdb_bow_client_conf/serving_client_conf.prototxt imdb.vocab + tail bow_rpc + check_result $FUNCNAME + kill_server_process + sleep 5 +} + +function lstm_rpc(){ + setproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + sed -i 's/8866/8867/g' test_client.py + python2.7 -m paddle_serving_server.serve --model imdb_lstm_model/ --port 8867 > lstm_rpc 2>&1 & + sleep 5 + head test_data/part-0 | python2.7 test_client.py imdb_lstm_client_conf/serving_client_conf.prototxt imdb.vocab + tail lstm_rpc + check_result $FUNCNAME + kill_server_process + kill `ps -ef|grep imdb|awk '{print $2}'` +} + +function lac_rpc(){ + setproxy + run_cpu_env + cd ${build_path}/python/examples/lac + python2.7 -m paddle_serving_app.package --get_model lac >/dev/null 2>&1 + tar xf lac.tar.gz + sed -i 's/9292/8868/g' lac_client.py + python2.7 -m paddle_serving_server.serve --model lac_model/ --port 8868 > lac_rpc 2>&1 & + sleep 5 + echo "我爱北京天安门" | python2.7 lac_client.py lac_client/serving_client_conf.prototxt lac_dict/ + tail lac_rpc + check_result $FUNCNAME + kill_server_process +} + +function fit_a_line_rpc(){ + setproxy + run_cpu_env + cd ${build_path}/python/examples/fit_a_line + sh get_data.sh >/dev/null 2>&1 + sed -i 's/9393/8869/g' test_client.py + python2.7 -m paddle_serving_server.serve --model uci_housing_model --port 8869 > line_rpc 2>&1 & + sleep 5 + python2.7 test_client.py uci_housing_client/serving_client_conf.prototxt + tail line_rpc + check_result $FUNCNAME + kill_server_process +} + +function faster_rcnn_model_rpc(){ + run_gpu_env + setproxy + cd ${build_path}/python/examples/faster_rcnn_model + cp -r /root/.cache/dist_data/serving/faster_rcnn/faster_rcnn_model.tar.gz ./ + tar xf faster_rcnn_model.tar.gz + wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml >/dev/null 2>&1 + mv faster_rcnn_model/pddet* ./ + sed -i 's/9494/8870/g' test_client.py + python2.7 -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 8870 --gpu_id 0 > faster_rcnn_rpc 2>&1 & + sleep 3 + python2.7 test_client.py pddet_client_conf/serving_client_conf.prototxt infer_cfg.yml 000000570688.jpg + tail faster_rcnn_rpc + check_result $FUNCNAME + kill_server_process +} + +function cascade_rcnn_rpc(){ + setproxy + run_gpu_env + cd ${build_path}/python/examples/cascade_rcnn + cp -r /root/.cache/dist_data/serving/cascade_rcnn/cascade_rcnn_r50_fpx_1x_serving.tar.gz ./ + tar xf cascade_rcnn_r50_fpx_1x_serving.tar.gz + sed -i "s/9292/8879/g" test_client.py + python2.7 -m paddle_serving_server_gpu.serve --model serving_server --port 8879 --gpu_id 0 > rcnn_rpc 2>&1 & + ls -hlst + sleep 5 + python2.7 test_client.py + tail rcnn_rpc + check_result $FUNCNAME + kill_server_process +} + +function deeplabv3_rpc() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/deeplabv3 + cp -r /root/.cache/dist_data/serving/deeplabv3/deeplabv3.tar.gz ./ + tar xf deeplabv3.tar.gz + sed -i "s/9494/8880/g" deeplabv3_client.py + python2.7 -m paddle_serving_server_gpu.serve --model deeplabv3_server --gpu_ids 0 --port 8880 > deeplab_rpc 2>&1 & + sleep 5 + python2.7 deeplabv3_client.py + tail deeplab_rpc + check_result $FUNCNAME + kill_server_process +} + +function mobilenet_rpc() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/mobilenet + python2.7 -m paddle_serving_app.package --get_model mobilenet_v2_imagenet >/dev/null 2>&1 + tar xf mobilenet_v2_imagenet.tar.gz + sed -i "s/9393/8881/g" mobilenet_tutorial.py + python2.7 -m paddle_serving_server_gpu.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 > mobilenet_rpc 2>&1 & + sleep 5 + python2.7 mobilenet_tutorial.py + tail mobilenet_rpc + check_result $FUNCNAME + kill_server_process +} + +function unet_rpc() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/unet_for_image_seg + python2.7 -m paddle_serving_app.package --get_model unet >/dev/null 2>&1 + tar xf unet.tar.gz + sed -i "s/9494/8882/g" seg_client.py + python2.7 -m paddle_serving_server_gpu.serve --model unet_model --gpu_ids 0 --port 8882 > unet_rpc 2>&1 & + sleep 5 + python2.7 seg_client.py + tail unet_rpc + check_result $FUNCNAME + kill_server_process +} + +function resnetv2_rpc() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/resnet_v2_50 + cp /root/.cache/dist_data/serving/resnet_v2_50/resnet_v2_50_imagenet.tar.gz ./ + tar xf resnet_v2_50_imagenet.tar.gz + sed -i 's/9393/8883/g' resnet50_v2_tutorial.py + python2.7 -m paddle_serving_server_gpu.serve --model resnet_v2_50_imagenet_model --gpu_ids 0 --port 8883 > v2_log 2>&1 & + sleep 10 + python2.7 resnet50_v2_tutorial.py + tail v2_log + check_result $FUNCNAME + kill_server_process +} + +function ocr_rpc() { + setproxy + run_cpu_env + cd ${build_path}/python/examples/ocr + cp -r /root/.cache/dist_data/serving/ocr/test_imgs ./ + python2.7 -m paddle_serving_app.package --get_model ocr_rec >/dev/null 2>&1 + tar xf ocr_rec.tar.gz + sed -i 's/9292/8884/g' test_ocr_rec_client.py + python2.7 -m paddle_serving_server.serve --model ocr_rec_model --port 8884 > ocr_rpc 2>&1 & + sleep 5 + python2.7 test_ocr_rec_client.py + tail ocr_rpc + check_result $FUNCNAME + kill_server_process +} + +function criteo_ctr_rpc_cpu() { + setproxy + run_cpu_env + cd ${build_path}/python/examples/criteo_ctr + sed -i "s/9292/8885/g" test_client.py + ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./ + wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1 + tar xf criteo_ctr_demo_model.tar.gz + mv models/ctr_client_conf . + mv models/ctr_serving_model . + python2.7 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8885 > criteo_ctr_cpu_rpc 2>&1 & + sleep 5 + python2.7 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 + tail criteo_ctr_cpu_rpc + check_result $FUNCNAME + kill_server_process +} + +function criteo_ctr_rpc_gpu() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/criteo_ctr + sed -i "s/8885/8886/g" test_client.py + python2.7 -m paddle_serving_server_gpu.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 > criteo_ctr_gpu_rpc 2>&1 & + sleep 5 + python2.7 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/ + tail criteo_ctr_gpu_rpc + check_result $FUNCNAME + kill_server_process +} + +function yolov4_rpc_gpu() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/yolov4 + sed -i "s/9393/8887/g" test_client.py + cp -r /root/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./ + tar xf yolov4.tar.gz + python2.7 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 > yolov4_rpc_log 2>&1 & + sleep 5 + python2.7 test_client.py 000000570688.jpg + tail yolov4_rpc_log +# check_result $FUNCNAME + kill_server_process +} + +function senta_rpc_cpu() { + setproxy + run_gpu_env + cd ${build_path}/python/examples/senta + sed -i "s/9393/8887/g" test_client.py + cp -r /data/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./ + tar xf yolov4.tar.gz + python2.7 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 > yolov4_rpc_log 2>&1 & + sleep 5 + python2.7 test_client.py 000000570688.jpg + tail yolov4_rpc_log + check_result $FUNCNAME + kill_server_process +} + +function fit_a_line_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/fit_a_line + sed -i "s/9292/8871/g" test_server.py + python2.7 test_server.py > http_log2 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://${host}:8871/uci/prediction + check_result $FUNCNAME + kill_server_process +} + +function lac_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/lac + python2.7 lac_web_service.py lac_model/ lac_workdir 8872 > http_lac_log2 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' http://${host}:8872/lac/prediction + check_result $FUNCNAME + kill_server_process +} + +function cnn_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + python2.7 text_classify_service.py imdb_cnn_model/ workdir/ 8873 imdb.vocab > cnn_http 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8873/imdb/prediction + check_result $FUNCNAME + kill_server_process +} + +function bow_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + python2.7 text_classify_service.py imdb_bow_model/ workdir/ 8874 imdb.vocab > bow_http 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8874/imdb/prediction + check_result $FUNCNAME + kill_server_process +} + +function lstm_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + python2.7 text_classify_service.py imdb_bow_model/ workdir/ 8875 imdb.vocab > bow_http 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8875/imdb/prediction + check_result $FUNCNAME + kill_server_process +} + +function ResNet50_http() { + unsetproxy + run_gpu_env + cd ${build_path}/python2.7/examples/imagenet + python2.7 resnet50_web_service.py ResNet50_vd_model gpu 8876 > resnet50_http 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://${host}:8876/image/prediction + check_result $FUNCNAME + kill_server_process +} + +bert_http(){ + run_gpu_env + unsetproxy + cd ${build_path}/python/examples/bert + cp data-c.txt.1 data-c.txt + cp vocab.txt.1 vocab.txt + export CUDA_VISIBLE_DEVICES=0 + python2.7 bert_web_service.py bert_seq128_model/ 8878 > bert_http 2>&1 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://${host}:8878/bert/prediction + check_result $FUNCNAME + kill_server_process +} + +grpc_impl(){ + run_gpu_env + cd ${build_path}/python/examples/grpc_impl_example/fit_a_line + sh get_data.sh >/dev/null 2>&1 + python2.7 test_server.py uci_housing_model/ > grpclog 2>&1 & + sleep 5 + echo "sync predict" + python2.7 test_sync_client.py + echo "async predict" + python2.7 test_asyn_client.py + echo "batch predict" + python2.7 test_batch_client.py + echo "timeout predict" + python2.7 test_timeout_client.py + # check_result $FUNCNAME + kill_server_process +} + + +function build_all_whl(){ + for whl in ${build_whl_list[@]} + do + echo "===========${whl} begin build===========" + $whl + sleep 3 + echo "===========${whl} build over ===========" + done +} + +function run_rpc_models(){ + for model in ${rpc_model_list[@]} + do + echo "===========${model} run begin===========" + $model + sleep 3 + echo "===========${model} run end ===========" + done +} + +function run_http_models(){ + for model in ${http_model_list[@]} + do + echo "===========${model} run begin===========" + $model + sleep 3 + echo "===========${model} run end ===========" + done +} +function end_hook(){ + cd ${build_path} + kill_server_process + kill `ps -ef|grep python|awk '{print $2}'` + sleep 5 + echo "===========files===========" + ls -hlst + echo "=========== end ===========" +} + +function main() { + before_hook + build_all_whl + check + run_env + run_rpc_models +# run_http_models + end_hook + +} + + +main$@ diff --git a/tools/scripts/ipipe_py3.sh b/tools/scripts/ipipe_py3.sh new file mode 100644 index 0000000000000000000000000000000000000000..156b32d4acdc5fa63acd6e2c1d467dccb680d36b --- /dev/null +++ b/tools/scripts/ipipe_py3.sh @@ -0,0 +1,696 @@ +#!/bin/bash +echo "################################################################" +echo "# #" +echo "# #" +echo "# #" +echo "# Paddle Serving begin run with python3.6.8! #" +echo "# #" +echo "# #" +echo "# #" +echo "################################################################" + +export GOPATH=$HOME/go +export PATH=$PATH:$GOROOT/bin:$GOPATH/bin +export CUDA_INCLUDE_DIRS=/usr/local/cuda-10.2/include +export PYTHONROOT=/usr/local + +go env -w GO111MODULE=on +go env -w GOPROXY=https://goproxy.cn,direct +go get -u github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway@v1.15.2 +go get -u github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger@v1.15.2 +go get -u github.com/golang/protobuf/protoc-gen-go@v1.4.3 +go get -u google.golang.org/grpc@v1.33.0 + +build_path=/workspace/Serving/ +build_whl_list=(build_gpu_server build_client build_cpu_server build_app) +rpc_model_list=(grpc_impl pipeline_imagenet bert_rpc_gpu bert_rpc_cpu ResNet50_rpc lac_rpc \ +cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \ +criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu ocr_rpc yolov4_rpc_gpu) +http_model_list=(fit_a_line_http lac_http cnn_http bow_http lstm_http ResNet50_http bert_http) + + +function setproxy(){ + export http_proxy=${proxy} + export https_proxy=${proxy} +} + +function unsetproxy(){ + unset http_proxy + unset https_proxy +} + +function kill_server_process(){ + kill `ps -ef|grep $1 |awk '{print $2}'` + kill `ps -ef|grep serving |awk '{print $2}'` +} + +function check() { + cd ${build_path} + if [ ! -f paddle_serving_app* ]; then + echo "paddle_serving_app is compiled failed, please check your pull request" + exit 1 + elif [ ! -f paddle_serving_server-* ]; then + echo "paddle_serving_server-cpu is compiled failed, please check your pull request" + exit 1 + elif [ ! -f paddle_serving_server_* ]; then + echo "paddle_serving_server_gpu is compiled failed, please check your pull request" + exit 1 + elif [ ! -f paddle_serving_client* ]; then + echo "paddle_serving_server_client is compiled failed, please check your pull request" + exit 1 + else + echo "paddle serving build passed" + fi +} + +function check_result() { + if [ $? -ne 0 ];then + echo -e "\033[4;31;42m$1 model runs failed, please check your pull request or modify test case! \033[0m" + exit 1 + else + echo -e "\033[4;37;42m$1 model runs successfully, congratulations! \033[0m" + fi +} + +function before_hook(){ + setproxy + cd ${build_path}/python + pip3.6 install --upgrade pip + pip3.6 install requests + pip3.6 install -r requirements.txt + pip3.6 install numpy==1.16.4 + echo "before hook configuration is successful.... " +} + +function run_env(){ + setproxy + pip3.6 install --upgrade nltk==3.4 + pip3.6 install --upgrade scipy==1.2.1 + pip3.6 install --upgrade setuptools==41.0.0 + pip3.6 install paddlehub ujson paddlepaddle==2.0.0 + echo "run env configuration is successful.... " +} + +function run_gpu_env(){ + cd ${build_path} + export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/workspace/Serving/build_gpu/third_party/install/Paddle/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mklml/lib/:/workspace/Serving/build_gpu/third_party/Paddle/src/extern_paddle/third_party/install/mkldnn/lib/:$LD_LIBRARY_PATH + export SERVING_BIN=${build_path}/build_gpu/core/general-server/serving + echo "run gpu env configuration is successful.... " +} + +function run_cpu_env(){ + cd ${build_path} + export LD_LIBRARY_PATH=/usr/local/lib64/python3.6/site-packages/paddle/libs/:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/workspace/Serving/build_cpu/third_party/install/Paddle/lib/:$LD_LIBRARY_PATH + export SERVING_BIN=${build_path}/build_cpu/core/general-server/serving + echo "run cpu env configuration is successful.... " +} + +function build_gpu_server() { + setproxy + cd ${build_path} + git submodule update --init --recursive + if [ -d build_gpu ];then + rm -rf build_gpu + fi + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \ + -DSERVER=ON \ + -DTENSORRT_ROOT=/usr \ + -DWITH_GPU=ON .. + make -j18 + make -j18 + make install -j18 + pip3.6 uninstall paddle-serving-server-gpu -y + pip3.6 install ${build_path}/build/python/dist/* + cp ${build_path}/build/python/dist/* ../ + cp -r ${build_path}/build/ ${build_path}/build_gpu +} + +function build_client() { + setproxy + cd ${build_path} + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \ + -DCLIENT=ON .. + make -j18 + make -j18 + cp ${build_path}/build/python/dist/* ../ + pip3.6 uninstall paddle-serving-client -y + pip3.6 install ${build_path}/build/python/dist/* +} + +function build_cpu_server(){ + setproxy + cd ${build_path} + if [ -d build_cpu ];then + rm -rf build_cpu + fi + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython3.6.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \ + -DWITH_GPU=OFF \ + -DSERVER=ON .. + make -j18 + make -j18 + make install -j18 + cp ${build_path}/build/python/dist/* ../ + pip3.6 uninstall paddle-serving-server -y + pip3.6 install ${build_path}/build/python/dist/* + cp -r ${build_path}/build/ ${build_path}/build_cpu +} + +function build_app() { + setproxy + pip3.6 install paddlehub ujson Pillow + pip3.6 install paddlepaddle==2.0.0 + cd ${build_path} + if [ -d build ];then + rm -rf build + fi + mkdir build && cd build + cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python3.6m/ \ + -DPYTHON_LIBRARIES=$PYTHONROOT/lib/libpython3.6.so \ + -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python3.6 \ + -DCMAKE_INSTALL_PREFIX=./output -DAPP=ON .. + make + cp ${build_path}/build/python/dist/* ../ + pip3.6 uninstall paddle-serving-app -y + pip3.6 install ${build_path}/build/python/dist/* +} + +function bert_rpc_gpu(){ + run_gpu_env + unsetproxy + cd ${build_path}/python/examples/bert + sh get_data.sh >/dev/null 2>&1 + sed -i 's/9292/8860/g' bert_client.py + sed -i '$aprint(result)' bert_client.py + cp -r /root/.cache/dist_data/serving/bert/bert_seq128_* ./ + ls -hlst + python3.6 -m paddle_serving_server_gpu.serve --model bert_seq128_model/ --port 8860 --gpu_ids 0 & + sleep 15 + nvidia-smi + head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function bert_rpc_cpu(){ + run_cpu_env + unsetproxy + cd ${build_path}/python/examples/bert + sed -i 's/8860/8861/g' bert_client.py + python3.6 -m paddle_serving_server.serve --model bert_seq128_model/ --port 8861 & + sleep 3 + cp data-c.txt.1 data-c.txt + head data-c.txt | python3.6 bert_client.py --model bert_seq128_client/serving_client_conf.prototxt + check_result $FUNCNAME + kill_server_process serving +} + +function criteo_ctr_with_cube_rpc(){ + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/criteo_ctr_with_cube + ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./ + sed -i "s/9292/8888/g" test_server.py + sed -i "s/9292/8888/g" test_client.py + wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz >/dev/null 2>&1 + tar xf ctr_cube_unittest.tar.gz + mv models/ctr_client_conf ./ + mv models/ctr_serving_model_kv ./ + mv models/data ./cube/ + wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz >/dev/null 2>&1 + tar xf cube_app.tar.gz + mv cube_app/cube* ./cube/ + sh cube_prepare.sh > haha 2>&1 & + sleep 5 + python3.6 test_server.py ctr_serving_model_kv & + sleep 5 + python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt ./raw_data + check_result $FUNCNAME + kill `ps -ef|grep cube|awk '{print $2}'` + kill_server_process test_server +} + +function pipeline_imagenet(){ + run_gpu_env + unsetproxy + cd ${build_path}/python/examples/pipeline/imagenet + cp -r /root/.cache/dist_data/serving/imagenet/* ./ + ls -a + python3.6 resnet50_web_service.py & + sleep 5 + nvidia-smi + python3.6 pipeline_rpc_client.py + nvidia-smi + # check_result $FUNCNAME + kill_server_process resnet50_web_service +} + +function ResNet50_rpc(){ + run_gpu_env + unsetproxy + cd ${build_path}/python/examples/imagenet + cp -r /root/.cache/dist_data/serving/imagenet/* ./ + sed -i 's/9696/8863/g' resnet50_rpc_client.py + python3.6 -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 8863 --gpu_ids 0 & + sleep 5 + nvidia-smi + python3.6 resnet50_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function ResNet101_rpc(){ + run_gpu_env + unsetproxy + cd ${build_path}/python/examples/imagenet + sed -i "22cclient.connect(['${host}:8864'])" image_rpc_client.py + python3.6 -m paddle_serving_server_gpu.serve --model ResNet101_vd_model --port 8864 --gpu_ids 0 & + sleep 5 + nvidia-smi + python3.6 image_rpc_client.py ResNet101_vd_client_config/serving_client_conf.prototxt + nvidia-smi + check_result $FUNCNAME + kill_server_process serving + sleep 5 +} + +function cnn_rpc(){ + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + cp -r /root/.cache/dist_data/serving/imdb/* ./ + tar xf imdb_model.tar.gz && tar xf text_classification_data.tar.gz + sed -i 's/9292/8865/g' test_client.py + python3.6 -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 & + sleep 5 + head test_data/part-0 | python3.6 test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab + check_result $FUNCNAME + kill_server_process serving +} + +function bow_rpc(){ + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + sed -i 's/8865/8866/g' test_client.py + python3.6 -m paddle_serving_server.serve --model imdb_bow_model/ --port 8866 & + sleep 5 + head test_data/part-0 | python3.6 test_client.py imdb_bow_client_conf/serving_client_conf.prototxt imdb.vocab + check_result $FUNCNAME + kill_server_process serving +} + +function lstm_rpc(){ + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + sed -i 's/8866/8867/g' test_client.py + python3.6 -m paddle_serving_server.serve --model imdb_lstm_model/ --port 8867 & + sleep 5 + head test_data/part-0 | python3.6 test_client.py imdb_lstm_client_conf/serving_client_conf.prototxt imdb.vocab + check_result $FUNCNAME + kill_server_process serving +} + +function lac_rpc(){ + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/lac + python3.6 -m paddle_serving_app.package --get_model lac >/dev/null 2>&1 + tar xf lac.tar.gz + sed -i 's/9292/8868/g' lac_client.py + python3.6 -m paddle_serving_server.serve --model lac_model/ --port 8868 & + sleep 5 + echo "我爱北京天安门" | python3.6 lac_client.py lac_client/serving_client_conf.prototxt lac_dict/ + check_result $FUNCNAME + kill_server_process serving +} + +function fit_a_line_rpc(){ + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/fit_a_line + sh get_data.sh >/dev/null 2>&1 + sed -i 's/9393/8869/g' test_client.py + python3.6 -m paddle_serving_server.serve --model uci_housing_model --port 8869 & + sleep 5 + python3.6 test_client.py uci_housing_client/serving_client_conf.prototxt + check_result $FUNCNAME + kill_server_process serving +} + +function faster_rcnn_model_rpc(){ + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/faster_rcnn + cp -r /root/.cache/dist_data/serving/faster_rcnn/faster_rcnn_model.tar.gz ./ + tar xf faster_rcnn_model.tar.gz + wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml >/dev/null 2>&1 + mv faster_rcnn_model/pddet* ./ + sed -i 's/9494/8870/g' test_client.py + python3.6 -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 8870 --gpu_id 0 --thread 2 & + echo "faster rcnn running ..." + nvidia-smi + sleep 5 + python3.6 test_client.py pddet_client_conf/serving_client_conf.prototxt infer_cfg.yml 000000570688.jpg + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function cascade_rcnn_rpc(){ + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/cascade_rcnn + cp -r /root/.cache/dist_data/serving/cascade_rcnn/cascade_rcnn_r50_fpx_1x_serving.tar.gz ./ + tar xf cascade_rcnn_r50_fpx_1x_serving.tar.gz + sed -i "s/9292/8879/g" test_client.py + python3.6 -m paddle_serving_server_gpu.serve --model serving_server --port 8879 --gpu_id 0 --thread 2 & + sleep 5 + nvidia-smi + python3.6 test_client.py + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function deeplabv3_rpc() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/deeplabv3 + cp -r /root/.cache/dist_data/serving/deeplabv3/deeplabv3.tar.gz ./ + tar xf deeplabv3.tar.gz + sed -i "s/9494/8880/g" deeplabv3_client.py + python3.6 -m paddle_serving_server_gpu.serve --model deeplabv3_server --gpu_ids 0 --port 8880 --thread 2 & + sleep 5 + nvidia-smi + python3.6 deeplabv3_client.py + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function mobilenet_rpc() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/mobilenet + python3.6 -m paddle_serving_app.package --get_model mobilenet_v2_imagenet >/dev/null 2>&1 + tar xf mobilenet_v2_imagenet.tar.gz + sed -i "s/9393/8881/g" mobilenet_tutorial.py + python3.6 -m paddle_serving_server_gpu.serve --model mobilenet_v2_imagenet_model --gpu_ids 0 --port 8881 & + sleep 5 + nvidia-smi + python3.6 mobilenet_tutorial.py + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function unet_rpc() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/unet_for_image_seg + python3.6 -m paddle_serving_app.package --get_model unet >/dev/null 2>&1 + tar xf unet.tar.gz + sed -i "s/9494/8882/g" seg_client.py + python3.6 -m paddle_serving_server_gpu.serve --model unet_model --gpu_ids 0 --port 8882 & + sleep 5 + nvidia-smi + python3.6 seg_client.py + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function resnetv2_rpc() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/resnet_v2_50 + cp /root/.cache/dist_data/serving/resnet_v2_50/resnet_v2_50_imagenet.tar.gz ./ + tar xf resnet_v2_50_imagenet.tar.gz + sed -i 's/9393/8883/g' resnet50_v2_tutorial.py + python3.6 -m paddle_serving_server_gpu.serve --model resnet_v2_50_imagenet_model --gpu_ids 0 --port 8883 & + sleep 10 + nvidia-smi + python3.6 resnet50_v2_tutorial.py + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + +function ocr_rpc() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/ocr + cp -r /root/.cache/dist_data/serving/ocr/test_imgs ./ + python3.6 -m paddle_serving_app.package --get_model ocr_rec >/dev/null 2>&1 + tar xf ocr_rec.tar.gz + sed -i 's/9292/8884/g' test_ocr_rec_client.py + python3.6 -m paddle_serving_server.serve --model ocr_rec_model --port 8884 & + sleep 5 + python3.6 test_ocr_rec_client.py + # check_result $FUNCNAME + kill_server_process serving +} + +function criteo_ctr_rpc_cpu() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/criteo_ctr + sed -i "s/9292/8885/g" test_client.py + ln -s /root/.cache/dist_data/serving/criteo_ctr_with_cube/raw_data ./ + wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1 + tar xf criteo_ctr_demo_model.tar.gz + mv models/ctr_client_conf . + mv models/ctr_serving_model . + python3.6 -m paddle_serving_server.serve --model ctr_serving_model/ --port 8885 & + sleep 5 + python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/part-0 + check_result $FUNCNAME + kill_server_process serving +} + +function criteo_ctr_rpc_gpu() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/criteo_ctr + sed -i "s/8885/8886/g" test_client.py + wget https://paddle-serving.bj.bcebos.com/criteo_ctr_example/criteo_ctr_demo_model.tar.gz >/dev/null 2>&1 + python3.6 -m paddle_serving_server_gpu.serve --model ctr_serving_model/ --port 8886 --gpu_ids 0 & + sleep 5 + nvidia-smi + python3.6 test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data/ + nvidia-smi + check_result $FUNCNAME + kill `ps -ef|grep ctr|awk '{print $2}'` + kill_server_process serving +} + +function yolov4_rpc_gpu() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/yolov4 + sed -i "s/9393/8887/g" test_client.py + cp -r /root/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./ + tar xf yolov4.tar.gz + python3.6 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 & + nvidia-smi + sleep 5 + python3.6 test_client.py 000000570688.jpg + nvidia-smi + # check_result $FUNCNAME + kill_server_process serving +} + +function senta_rpc_cpu() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/senta + sed -i "s/9393/8887/g" test_client.py + cp -r /data/.cache/dist_data/serving/yolov4/yolov4.tar.gz ./ + tar xf yolov4.tar.gz + python3.6 -m paddle_serving_server_gpu.serve --model yolov4_model --port 8887 --gpu_ids 0 & + nvidia-smi + sleep 5 + python3.6 test_client.py 000000570688.jpg + nvidia-smi + check_result $FUNCNAME + kill_server_process serving +} + + +function fit_a_line_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/fit_a_line + sed -i "s/9292/8871/g" test_server.py + python3.6 test_server.py & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://${host}:8871/uci/prediction + check_result $FUNCNAME + kill_server_process test_server +} + +function lac_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/lac + python3.6 lac_web_service.py lac_model/ lac_workdir 8872 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "我爱北京天安门"}], "fetch":["word_seg"]}' http://${host}:8872/lac/prediction + check_result $FUNCNAME + kill_server_process lac_web_service +} + +function cnn_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + python3.6 text_classify_service.py imdb_cnn_model/ workdir/ 8873 imdb.vocab & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8873/imdb/prediction + check_result $FUNCNAME + kill_server_process text_classify_service +} + +function bow_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8874 imdb.vocab & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8874/imdb/prediction + check_result $FUNCNAME + kill_server_process text_classify_service +} + +function lstm_http() { + unsetproxy + run_cpu_env + cd ${build_path}/python/examples/imdb + python3.6 text_classify_service.py imdb_bow_model/ workdir/ 8875 imdb.vocab & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "i am very sad | 0"}], "fetch":["prediction"]}' http://${host}:8875/imdb/prediction + check_result $FUNCNAME + kill `ps -ef|grep imdb|awk '{print $2}'` + kill_server_process text_classify_service +} + +function ResNet50_http() { + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/imagenet + python3.6 resnet50_web_service.py ResNet50_vd_model gpu 8876 & + sleep 10 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://${host}:8876/image/prediction + check_result $FUNCNAME + kill_server_process resnet50_web_service +} + +function bert_http(){ + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/bert + cp data-c.txt.1 data-c.txt + cp vocab.txt.1 vocab.txt + export CUDA_VISIBLE_DEVICES=0 + python3.6 bert_web_service.py bert_seq128_model/ 8878 & + sleep 5 + curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"words": "hello"}], "fetch":["pooled_output"]}' http://127.0.0.1:8878/bert/prediction + check_result $FUNCNAME + kill_server_process bert_web_service +} + +grpc_impl(){ + unsetproxy + run_gpu_env + cd ${build_path}/python/examples/grpc_impl_example/fit_a_line + sh get_data.sh >/dev/null 2>&1 + python3.6 test_server.py uci_housing_model/ & + sleep 5 + echo "sync predict" + python3.6 test_sync_client.py + echo "async predict" + python3.6 test_asyn_client.py + echo "batch predict" + python3.6 test_batch_client.py + echo "timeout predict" + python3.6 test_timeout_client.py +# check_result $FUNCNAME + kill_server_process test_server +} + +function build_all_whl(){ + for whl in ${build_whl_list[@]} + do + echo "===========${whl} begin build===========" + $whl + sleep 3 + echo "===========${whl} build over ===========" + done +} + +function run_rpc_models(){ + for model in ${rpc_model_list[@]} + do + echo "===========${model} run begin===========" + $model + sleep 3 + echo "===========${model} run end ===========" + done +} + +function run_http_models(){ + for model in ${http_model_list[@]} + do + echo "===========${model} run begin===========" + $model + sleep 3 + echo "===========${model} run end ===========" + done +} + +function end_hook(){ + cd ${build_path} + kill_server_process + kill `ps -ef|grep python|awk '{print $2}'` + sleep 5 + echo "===========files===========" + ls -hlst + echo "=========== end ===========" + +} + +function main() { + before_hook + build_all_whl + check + run_env + run_rpc_models +# run_http_models + end_hook +} + + +main$@ diff --git a/tools/serving_build.sh b/tools/serving_build.sh index 5d5abaf64c575d6d3728d1c1fdea281f94e3c2d6..fb0b75ae51a3716c55079d5aab4ddcc7b168680b 100644 --- a/tools/serving_build.sh +++ b/tools/serving_build.sh @@ -932,7 +932,6 @@ function python_app_api_test(){ cd imagenet case $TYPE in CPU) - check_cmd "python test_image_reader.py" ;; GPU) echo "no implement for cpu type"