未验证 提交 4fc7e121 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge pull request #853 from BohaoWu/Zelda

[Zelda] GPU 预处理代码
......@@ -25,6 +25,7 @@ endif()
if (APP)
add_subdirectory(configure)
add_subdirectory(preprocess)
endif()
......
if(NOT WITH_GPU)
message("GPU preprocess will not be compiled.")
return()
endif()
message(STATUS "CUDA detected: " ${CUDA_VERSION})
if (${CUDA_VERSION} LESS 10.0)
message("CUDA version should be 10.0.")
return()
elseif (${CUDA_VERSION} LESS 10.1) # CUDA 10.0
add_subdirectory(hwvideoframe)
endif()
cmake_minimum_required(VERSION 3.2)
project(gpupreprocess)
include(cuda)
include(configure)
#C flags
set(CMAKE_C_FLAGS " -g -pipe -W -Wall -fPIC")
#C++ flags.
set(CMAKE_CXX_FLAGS " -g -pipe -W -Wall -fPIC -std=c++11")
add_subdirectory(cuda)
set(PYTHON_SO {PYTHON_LIBRARY})
set(EXTRA_LIBS ${EXTRA_LIBS} gpu)
file(GLOB SOURCE_FILES pybind/*.cpp src/*.cpp)
include_directories("./include")
include_directories(${CUDA_INCLUDE_DIRS})
include_directories(/usr${PYTHON_INCLUDE_DIR})
link_directories("-L${CUDA_TOOLKIT_ROOT_DIR}/lib64 -lcudart -lnppidei_static -lnppial_static -lnpps_static -lnppc_static -lculibos")
link_directories(${PYTHON_SO})
#.so
add_library(gpupreprocess SHARED ${SOURCE_FILES})
target_link_libraries(gpupreprocess ${EXTRA_LIBS})
target_link_libraries(gpupreprocess ${CUDA_LIBRARIES})
# hwvideoframe
Hwvideoframe is a CV preprocessing library based on cuda. The project uses GPU for image preprocessing operations. It speeds up the processing speed while increasing the utilization rate of the GPU.
## Preprocess API
Hwvideoframe provides a variety of data preprocessing methods for photo preprocess:
- class Image2Gpubuffer
- `__call__(img)`
- img(np.array):Image data.
- class Gpubuffer2Image
- `__call__(img)`
- img(np.array):Image data.
- class Div
- `__init__(value)`
- value(float):Constant value to be divided.
- `__call__(img)`
- img(np.array):Image data.
- class Sub
- `__init__(subtractor)`
- subtractor(list/float):Three 32-bit floating point channel image subtract constant. When the input is a list type, length of list must be three.
- `__call__(img)`
- img(np.array):Image data in (C,H,W) channels.
- class Normalize
- `__init__(mean,std)`
- mean(list):Mean. Length of list must be three.
- std(list):Variance. Length of list must be three.
- `__call__(img)`
- img(np.array):Image data in (C,H,W) channels.
- class CenterCrop
- `__init__(size)`
- size(int):Crops the given Image at the center while the size must not bigger than any inputs' height and width.
- `__call__(img)`
- img(np.array):Image data in (C,H,W) channels.
- class Resize
- `__init__(size, max_size=2147483647, interpolation=None)`
- size(list/int):The expected image size, when the input is a list type, it needs to contain the expected length and width. When the input is int type, the short side will be set to the length of size, and the long side will be scaled proportionally.
- `__call__(img)`
- img(numpy array):Image data in (C,H,W) channels.
## Quick start
[After compiling from code](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE.md),this project will be stored in reader。
## How to Test
Test file:Serving/python/paddle_serving_app/reader/test_preprocess.py
# hwvideoframe
hwvideoframe是一个基于cuda的图片预处理库。项目利用GPU进行图片预处理操作,在加快处理速度的同时,提高GPU的使用率。
## 项目结构
hwvideoframe目前提供以下图片预处理功能:
- class Image2Gpubuffer
- `__call__(img)`
- img(np.array):输入图像。
- class Gpubuffer2Image
- `__call__(img)`
- img(np.array):输入图像。
- class Div
- `__init__(value)`
- value(float):根据固定值切割图像。
- `__call__(img)`
- img(np.array):输入图像。
- class Sub
- `__init__(subtractor)`
- subtractor(list/float):list的长度必须为3。
- `__call__(img)`
- img(np.array):(C,H,W)排列的图像数据。
- class Normalize
- `__init__(mean,std)`
- mean(list):均值。 list长度必须为3。
- std(list):方差。 list长度必须为3。
- `__call__(img)`
- img(np.array):(C,H,W)排列的图像数据。
- class CenterCrop
- `__init__(size)`
- size(int):预期的裁剪后的大小,list类型时需要包含预期的长和宽,int类型时会返回边长为size的正方形图片。size不能大于原始图片大小。
- `__call__(img)`
- img(np.array):(C,H,W)排列的图像数据
- class Resize
- `__init__(size, max_size=2147483647, interpolation=None)`
- size(list/int):预期的图像大小,短边会设置为size的长度,长边按比例缩放.
- `__call__(img)`
- img(numpy array):(C,H,W)排列的图像数据
## 快速开始
按照Paddle Serving文档编译,编译结果在reader中。
## 测试
测试文件:Serving/python/paddle_serving_app/reader/test_preprocess.py
cmake_minimum_required(VERSION 3.2)
project(gpu)
FIND_PACKAGE(CUDA ${CUDA_VERSION} REQUIRED)
SET(CUDA_TARGET_INCLUDE ${CUDA_TOOLKIT_ROOT_DIR}-${CUDA_VERSION}/targets/${CMAKE_HOST_SYSTEM_PROCESSOR}-${LOWER_SYSTEM_NAME}/include)
file(GLOB_RECURSE CURRENT_HEADERS *.h *.hpp *.cuh)
file(GLOB CURRENT_SOURCES *.cpp *.cu)
source_group("Include" FILES ${CURRENT_HEADERS})
source_group("Source" FILES ${CURRENT_SOURCES})
set(CMAKE_CUDA_FLAGS "-ccbin /opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
set(CUDA_NVCC_FLAGS "-L/opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
include_directories(${CUDA_INCLUDE_DIRS})
cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES})
target_link_libraries(gpu ${CUDA_LIBS})
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "./cuda_runtime.h"
#define clip(x, a, b) x >= a ? (x < b ? x : b - 1) : a;
const int INTER_RESIZE_COEF_BITS = 11;
const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS;
__global__ void resizeCudaKernel(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels) {
// 2D Index of current thread
const int dx = blockIdx.x * blockDim.x + threadIdx.x;
const int dy = blockIdx.y * blockDim.y + threadIdx.y;
if ((dx < outputWidth) && (dy < outputHeight)) {
if (inputChannels == 1) { // grayscale image
// TODO(Zelda): support grayscale
} else if (inputChannels == 3) { // RGB image
double scale_x = static_cast<double>(inputWidth / outputWidth);
double scale_y = static_cast<double>(inputHeight / outputHeight);
int xmax = outputWidth;
float fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
int sx = floorf(fx);
fx = fx - sx;
int isx1 = sx;
if (isx1 < 0) {
fx = 0.0;
isx1 = 0;
}
if (isx1 >= (inputWidth - 1)) {
xmax = ::min(xmax, dx);
fx = 0;
isx1 = inputWidth - 1;
}
float2 cbufx;
cbufx.x = (1.f - fx);
cbufx.y = fx;
float fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
int sy = floorf(fy);
fy = fy - sy;
int isy1 = clip(sy + 0, 0, inputHeight);
int isy2 = clip(sy + 1, 0, inputHeight);
float2 cbufy;
cbufy.x = (1.f - fy);
cbufy.y = fy;
int isx2 = isx1 + 1;
float3 d0;
float3 s11 =
make_float3(input[(isy1 * inputWidth + isx1) * inputChannels + 0],
input[(isy1 * inputWidth + isx1) * inputChannels + 1],
input[(isy1 * inputWidth + isx1) * inputChannels + 2]);
float3 s12 =
make_float3(input[(isy1 * inputWidth + isx2) * inputChannels + 0],
input[(isy1 * inputWidth + isx2) * inputChannels + 1],
input[(isy1 * inputWidth + isx2) * inputChannels + 2]);
float3 s21 =
make_float3(input[(isy2 * inputWidth + isx1) * inputChannels + 0],
input[(isy2 * inputWidth + isx1) * inputChannels + 1],
input[(isy2 * inputWidth + isx1) * inputChannels + 2]);
float3 s22 =
make_float3(input[(isy2 * inputWidth + isx2) * inputChannels + 0],
input[(isy2 * inputWidth + isx2) * inputChannels + 1],
input[(isy2 * inputWidth + isx2) * inputChannels + 2]);
float h_rst00, h_rst01;
// B
if (dx > xmax - 1) {
h_rst00 = s11.x;
h_rst01 = s21.x;
} else {
h_rst00 = s11.x * cbufx.x + s12.x * cbufx.y;
h_rst01 = s21.x * cbufx.x + s22.x * cbufx.y;
}
d0.x = h_rst00 * cbufy.x + h_rst01 * cbufy.y;
// G
if (dx > xmax - 1) {
h_rst00 = s11.y;
h_rst01 = s21.y;
} else {
h_rst00 = s11.y * cbufx.x + s12.y * cbufx.y;
h_rst01 = s21.y * cbufx.x + s22.y * cbufx.y;
}
d0.y = h_rst00 * cbufy.x + h_rst01 * cbufy.y;
// R
if (dx > xmax - 1) {
h_rst00 = s11.z;
h_rst01 = s21.z;
} else {
h_rst00 = s11.z * cbufx.x + s12.z * cbufx.y;
h_rst01 = s21.z * cbufx.x + s22.z * cbufx.y;
}
d0.z = h_rst00 * cbufy.x + h_rst01 * cbufy.y;
output[(dy * outputWidth + dx) * 3 + 0] = (d0.x); // R
output[(dy * outputWidth + dx) * 3 + 1] = (d0.y); // G
output[(dy * outputWidth + dx) * 3 + 2] = (d0.z); // B
} else {
// TODO(Zelda): support alpha channel
}
}
}
__global__ void resizeCudaKernel_fixpt(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels) {
// 2D Index of current thread
const int dx = blockIdx.x * blockDim.x + threadIdx.x;
const int dy = blockIdx.y * blockDim.y + threadIdx.y;
if ((dx < outputWidth) && (dy < outputHeight)) {
if (inputChannels == 1) { // grayscale image
// TODO(Zelda): support grayscale
} else if (inputChannels == 3) { // RGB image
double scale_x = static_cast<double>(inputWidth / outputWidth);
double scale_y = static_cast<double>(inputHeight / outputHeight);
int xmax = outputWidth;
float fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
int sx = floorf(fx);
fx = fx - sx;
int isx1 = sx;
if (isx1 < 0) {
fx = 0.0;
isx1 = 0;
}
if (isx1 >= (inputWidth - 1)) {
xmax = ::min(xmax, dx);
fx = 0;
isx1 = inputWidth - 1;
}
short2 cbufx;
cbufx.x = lrintf((1.f - fx) * INTER_RESIZE_COEF_SCALE);
cbufx.y = lrintf(fx * INTER_RESIZE_COEF_SCALE);
float fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
int sy = floorf(fy);
fy = fy - sy;
int isy1 = clip(sy + 0, 0, inputHeight);
int isy2 = clip(sy + 1, 0, inputHeight);
short2 cbufy;
cbufy.x = lrintf((1.f - fy) * INTER_RESIZE_COEF_SCALE);
cbufy.y = lrintf(fy * INTER_RESIZE_COEF_SCALE);
int isx2 = isx1 + 1;
uchar3 d0;
int3 s11 =
make_int3(input[(isy1 * inputWidth + isx1) * inputChannels + 0],
input[(isy1 * inputWidth + isx1) * inputChannels + 1],
input[(isy1 * inputWidth + isx1) * inputChannels + 2]);
int3 s12 =
make_int3(input[(isy1 * inputWidth + isx2) * inputChannels + 0],
input[(isy1 * inputWidth + isx2) * inputChannels + 1],
input[(isy1 * inputWidth + isx2) * inputChannels + 2]);
int3 s21 =
make_int3(input[(isy2 * inputWidth + isx1) * inputChannels + 0],
input[(isy2 * inputWidth + isx1) * inputChannels + 1],
input[(isy2 * inputWidth + isx1) * inputChannels + 2]);
int3 s22 =
make_int3(input[(isy2 * inputWidth + isx2) * inputChannels + 0],
input[(isy2 * inputWidth + isx2) * inputChannels + 1],
input[(isy2 * inputWidth + isx2) * inputChannels + 2]);
int h_rst00, h_rst01;
// B
if (dx > xmax - 1) {
h_rst00 = s11.x * INTER_RESIZE_COEF_SCALE;
h_rst01 = s21.x * INTER_RESIZE_COEF_SCALE;
} else {
h_rst00 = s11.x * cbufx.x + s12.x * cbufx.y;
h_rst01 = s21.x * cbufx.x + s22.x * cbufx.y;
}
d0.x = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) +
((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >>
2);
// G
if (dx > xmax - 1) {
h_rst00 = s11.y * INTER_RESIZE_COEF_SCALE;
h_rst01 = s21.y * INTER_RESIZE_COEF_SCALE;
} else {
h_rst00 = s11.y * cbufx.x + s12.y * cbufx.y;
h_rst01 = s21.y * cbufx.x + s22.y * cbufx.y;
}
d0.y = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) +
((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >>
2);
// R
if (dx > xmax - 1) {
h_rst00 = s11.z * INTER_RESIZE_COEF_SCALE;
h_rst01 = s21.z * INTER_RESIZE_COEF_SCALE;
} else {
h_rst00 = s11.z * cbufx.x + s12.z * cbufx.y;
h_rst01 = s21.z * cbufx.x + s22.z * cbufx.y;
}
d0.z = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) +
((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >>
2);
output[(dy * outputWidth + dx) * 3 + 0] = (d0.x); // R
output[(dy * outputWidth + dx) * 3 + 1] = (d0.y); // G
output[(dy * outputWidth + dx) * 3 + 2] = (d0.z); // B
} else {
// TODO(Zelda): support alpha channel
}
}
}
extern "C" cudaError_t resize_linear(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point) {
// Specify a reasonable block size
const dim3 block(16, 16);
// Calculate grid size to cover the whole image
const dim3 grid((outputWidth + block.x - 1) / block.x,
(outputHeight + block.y - 1) / block.y);
// Launch the size conversion kernel
if (use_fixed_point) {
resizeCudaKernel_fixpt<<<grid, block>>>(input,
output,
inputWidth,
inputHeight,
outputWidth,
outputHeight,
inputChannels);
} else {
resizeCudaKernel<<<grid, block>>>(input,
output,
inputWidth,
inputHeight,
outputWidth,
outputHeight,
inputChannels);
}
// Synchronize to check for any kernel launch errors
return cudaDeviceSynchronize();
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
#include <npp.h>
#include <memory>
#include "./op_context.h"
// Crops the given Image at the center.
// the size must not bigger than any inputs' height and width
class CenterCrop {
public:
explicit CenterCrop(int size) : _size(size) {}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _size;
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
#include <npp.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// divide by some float number for all pixel
class Div {
public:
explicit Div(float value);
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _divisor;
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
#include <npp.h>
#include <pybind11/numpy.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Input operator that copy numpy data to gpu buffer
class Image2Gpubuffer {
public:
std::shared_ptr<OpContext> operator()(pybind11::array_t<float> array);
};
// Output operator that copy gpu buffer data to numpy
class Gpubuffer2Image {
public:
pybind11::array_t<float> operator()(std::shared_ptr<OpContext> input);
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// utilize normalize operator on gpu
class Normalize {
public:
Normalize(const std::vector<float> &mean,
const std::vector<float> &std,
bool channel_first = false);
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _mean[CHANNEL_SIZE];
Npp32f _std[CHANNEL_SIZE];
bool _channel_first; // indicate whether the channel is dimension 0,
// unsupported
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
#include <npp.h>
const size_t CHANNEL_SIZE = 3;
// The context as input/ouput of all operators
// contains pointer to raw data on gpu, frame size
class OpContext {
public:
OpContext() {
_step = 0;
_size = 0;
_p_frame = nullptr;
}
// constructor to apply gpu memory of image raw data
OpContext(int height, int width) {
_step = sizeof(Npp32f) * width * CHANNEL_SIZE;
_length = height * width * CHANNEL_SIZE;
_size = _step * height;
_nppi_size.height = height;
_nppi_size.width = width;
cudaMalloc(reinterpret_cast<void**>(&_p_frame), _size);
}
virtual ~OpContext() { free_memory(); }
public:
Npp32f* p_frame() const { return _p_frame; }
int step() const { return _step; }
int length() const { return _length; }
int size() const { return _size; }
NppiSize& nppi_size() { return _nppi_size; }
void free_memory() {
if (_p_frame != nullptr) {
cudaFree(_p_frame);
_p_frame = nullptr;
}
_nppi_size.height = 0;
_nppi_size.width = 0;
_step = 0;
_size = 0;
}
private:
Npp32f* _p_frame; // pointer to raw data on gpu
int _step; // number of bytes in a row
int _length; // length of _p_frame, _size = _length * sizeof(Npp32f)
int _size; // number of bytes of the image
NppiSize _nppi_size; // contains height and width
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
extern "C" cudaError_t resize_linear(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point);
// Resize the input numpy array Image to the given size.
// only support linear interpolation
// only support RGB channels
class Resize {
public:
// size is an int, smaller edge of the image will be matched to this number.
Resize(int size,
int max_size = 214748364,
bool use_fixed_point = false,
int interpolation = 0)
: _size(size),
_max_size(max_size),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {}
// size is a sequence like (w, h), output size will be matched to this
Resize(std::vector<int> size,
int max_size = 214748364,
bool use_fixed_point = false,
int interpolation = 0)
: _size(-1),
_max_size(max_size),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {
_target_size[0] = size[0];
_target_size[1] = size[1];
}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _size; // target of smaller edge
int _target_size[2]; // target size sequence (w, h)
int _max_size;
bool _use_fixed_point;
int _interpolation; // unused
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
extern "C" cudaError_t resize_linear(const float *input,
float *output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point);
// Resize the input numpy array Image to a size multiple of factor which is
// usually required by a network
// only support linear interpolation
// only support RGB channels
class ResizeByFactor {
public:
// Resize factor. make width and height multiple factor of the value of
// factor. Default is 32
ResizeByFactor(int factor = 32,
int max_side_len = 2400,
bool use_fixed_point = false,
int interpolation = 0)
: _factor(factor),
_max_side_len(max_side_len),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _factor; // target of smaller edge
int _max_side_len;
bool _use_fixed_point;
int _interpolation; // unused
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
#include <npp.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// swap channel 0 and channel 2 for every pixel
// both RGB2BGR and BGR2RGB use this operator
class SwapChannel {
public:
SwapChannel() {}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
static const int
_ORDER[CHANNEL_SIZE]; // describing how channel values are permutated
};
class RGB2BGR : public SwapChannel {};
class BGR2RGB : public SwapChannel {};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// subtract by some float numbers
class Sub {
public:
explicit Sub(float subtractor);
explicit Sub(const std::vector<float> &subtractors);
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _subtractors[CHANNEL_SIZE];
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
#include <npp.h>
#include <string>
// verify return value of npp function
// throw an exception if failed
void verify_npp_ret(const std::string& function_name, NppStatus ret);
// verify return value of cuda runtime function
// throw an exception if failed
void verify_cuda_ret(const std::string& function_name, cudaError_t ret);
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "core/preprocess/hwvideoframe/include/center_crop.h"
#include "core/preprocess/hwvideoframe/include/div.h"
#include "core/preprocess/hwvideoframe/include/image_io.h"
#include "core/preprocess/hwvideoframe/include/normalize.h"
#include "core/preprocess/hwvideoframe/include/resize.h"
#include "core/preprocess/hwvideoframe/include/resize_by_factor.h"
#include "core/preprocess/hwvideoframe/include/rgb_swap.h"
#include "core/preprocess/hwvideoframe/include/sub.h"
PYBIND11_MODULE(libgpupreprocess, m) {
pybind11::class_<OpContext, std::shared_ptr<OpContext>>(m, "OpContext");
pybind11::class_<Image2Gpubuffer>(m, "Image2Gpubuffer")
.def(pybind11::init<>())
.def("__call__", &Image2Gpubuffer::operator());
pybind11::class_<Gpubuffer2Image>(m, "Gpubuffer2Image")
.def(pybind11::init<>())
.def("__call__", &Gpubuffer2Image::operator());
pybind11::class_<RGB2BGR>(m, "RGB2BGR")
.def(pybind11::init<>())
.def("__call__", &RGB2BGR::operator());
pybind11::class_<BGR2RGB>(m, "BGR2RGB")
.def(pybind11::init<>())
.def("__call__", &BGR2RGB::operator());
pybind11::class_<Div>(m, "Div")
.def(pybind11::init<float>())
.def("__call__", &Div::operator());
pybind11::class_<Sub>(m, "Sub")
.def(pybind11::init<float>())
.def(pybind11::init<const std::vector<float>&>())
.def("__call__", &Sub::operator());
pybind11::class_<Normalize>(m, "Normalize")
.def(pybind11::init<const std::vector<float>&,
const std::vector<float>&,
bool>(),
pybind11::arg("mean"),
pybind11::arg("std"),
pybind11::arg("channel_first") = false)
.def("__call__", &Normalize::operator());
pybind11::class_<CenterCrop>(m, "CenterCrop")
.def(pybind11::init<int>())
.def("__call__", &CenterCrop::operator());
pybind11::class_<Resize>(m, "Resize")
.def(pybind11::init<int, int, bool>(),
pybind11::arg("size"),
pybind11::arg("max_size") = 214748364,
pybind11::arg("use_fixed_point") = false)
.def(pybind11::init<const std::vector<int>&, int, bool>(),
pybind11::arg("target_size"),
pybind11::arg("max_size") = 214748364,
pybind11::arg("use_fixed_point") = false)
.def("__call__", &Resize::operator());
pybind11::class_<ResizeByFactor>(m, "ResizeByFactor")
.def(pybind11::init<int, int, bool>(),
pybind11::arg("factor") = 32,
pybind11::arg("max_side_len") = 2400,
pybind11::arg("use_fixed_point") = false)
.def("__call__", &ResizeByFactor::operator());
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <algorithm>
#include "core/preprocess/hwvideoframe/include/center_crop.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> CenterCrop::operator()(
std::shared_ptr<OpContext> input) {
int new_width = std::min(_size, input->nppi_size().width);
int new_height = std::min(_size, input->nppi_size().height);
auto output = std::make_shared<OpContext>(new_height, new_width);
int x_start = (input->nppi_size().width - new_width) / 2;
int y_start = (input->nppi_size().height - new_height) / 2;
Npp32f* p_src = input->p_frame() +
y_start * input->nppi_size().width * CHANNEL_SIZE +
x_start * CHANNEL_SIZE;
NppStatus ret = nppiCopy_32f_C3R(p_src,
input->step(),
output->p_frame(),
output->step(),
output->nppi_size());
verify_npp_ret("nppiCopy_32f_C3R", ret);
return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/div.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
Div::Div(float value) { _divisor = value; }
std::shared_ptr<OpContext> Div::operator()(std::shared_ptr<OpContext> input) {
NppStatus ret = nppsDivC_32f_I(_divisor, input->p_frame(), input->length());
verify_npp_ret("nppsDivC_32f_I", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <pybind11/numpy.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/image_io.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> Image2Gpubuffer::operator()(
pybind11::array_t<float> input) {
pybind11::buffer_info buf = input.request();
if (buf.format != pybind11::format_descriptor<float>::format()) {
throw std::runtime_error("Incompatible format: expected a float numpy!");
}
if (buf.ndim != 3) {
throw std::runtime_error("Number of dimensions must be three");
}
if (buf.shape[2] != CHANNEL_SIZE) {
throw std::runtime_error("Number of channels must be three");
}
auto result = std::make_shared<OpContext>(buf.shape[0], buf.shape[1]);
auto ret = cudaMemcpy(result->p_frame(),
static_cast<float*>(buf.ptr),
result->size(),
cudaMemcpyHostToDevice);
verify_cuda_ret("cudaMemcpy", ret);
return result;
}
pybind11::array_t<float> Gpubuffer2Image::operator()(
std::shared_ptr<OpContext> input) {
auto result = pybind11::array_t<float>({input->nppi_size().height,
input->nppi_size().width,
static_cast<int>(CHANNEL_SIZE)});
pybind11::buffer_info buf = result.request();
auto ret = cudaMemcpy(static_cast<float*>(buf.ptr),
input->p_frame(),
input->size(),
cudaMemcpyDeviceToHost);
verify_cuda_ret("cudaMemcpy", ret);
return result;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/normalize.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
Normalize::Normalize(const std::vector<float> &mean,
const std::vector<float> &std,
bool channel_first) {
if (mean.size() != CHANNEL_SIZE) {
throw std::runtime_error("size of mean must be three");
}
if (std.size() != CHANNEL_SIZE) {
throw std::runtime_error("size of std must be three");
}
for (size_t i = 0; i < CHANNEL_SIZE; i++) {
_mean[i] = mean[i];
_std[i] = std[i];
}
_channel_first = channel_first;
}
std::shared_ptr<OpContext> Normalize::operator()(
std::shared_ptr<OpContext> input) {
NppStatus ret = nppiSubC_32f_C3IR(
_mean, input->p_frame(), input->step(), input->nppi_size());
verify_npp_ret("nppiSubC_32f_C3IR", ret);
ret = nppiDivC_32f_C3IR(
_std, input->p_frame(), input->step(), input->nppi_size());
verify_npp_ret("nppiDivC_32f_C3IR", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/preprocess/hwvideoframe/include/resize.h"
#include <math.h>
#include <sstream>
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> Resize::operator()(
std::shared_ptr<OpContext> input) {
int resized_width = 0, resized_height = 0;
if (_size == -1) {
resized_width = std::min(_target_size[0], _max_size);
resized_height = std::min(_target_size[1], _max_size);
} else {
int im_max_size =
std::max(input->nppi_size().height, input->nppi_size().width);
float percent =
static_cast<float>(_size) /
std::min(input->nppi_size().height, input->nppi_size().width);
if (round(percent * im_max_size) > _max_size) {
percent = static_cast<float>(_max_size) / static_cast<float>(im_max_size);
}
resized_width = static_cast<int>(round(input->nppi_size().width * percent));
resized_height =
static_cast<int>(round(input->nppi_size().height * percent));
}
auto output = std::make_shared<OpContext>(resized_height, resized_width);
auto ret = resize_linear(input->p_frame(),
output->p_frame(),
input->nppi_size().width,
input->nppi_size().height,
output->nppi_size().width,
output->nppi_size().height,
CHANNEL_SIZE,
_use_fixed_point);
verify_cuda_ret("resize_linear", ret);
return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <sstream>
#include "core/preprocess/hwvideoframe/include/resize.h"
#include "core/preprocess/hwvideoframe/include/resize_by_factor.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> ResizeByFactor::operator()(
std::shared_ptr<OpContext> input) {
int resized_width = input->nppi_size().width,
resized_height = input->nppi_size().height;
float radio = 0;
if (std::max(resized_width, resized_height) > _max_side_len) {
if (resized_width > resized_height) {
radio = static_cast<float>(_max_side_len / resized_width);
} else {
radio = static_cast<float>(_max_side_len / resized_height);
}
} else {
radio = 1;
}
resized_width = static_cast<int>(resized_width * radio);
resized_height = static_cast<int>(resized_height * radio);
if (resized_height % _factor == 0) {
resized_height = resized_height;
} else if (floor(resized_height / _factor) <= 1) {
resized_height = _factor;
} else {
resized_height = (floor(resized_height / 32) - 1) * 32;
}
if (resized_width % _factor == 0) {
resized_width = resized_width;
} else if (floor(resized_width / _factor) <= 1) {
resized_width = _factor;
} else {
resized_width = (floor(resized_width / 32) - 1) * _factor;
}
if (static_cast<int>(resized_width) <= 0 ||
static_cast<int>(resized_height) <= 0) {
return NULL;
}
auto output = std::make_shared<OpContext>(resized_height, resized_width);
auto ret = resize_linear(input->p_frame(),
output->p_frame(),
input->nppi_size().width,
input->nppi_size().height,
output->nppi_size().width,
output->nppi_size().height,
CHANNEL_SIZE,
_use_fixed_point);
verify_cuda_ret("resize_linear", ret);
return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include "core/preprocess/hwvideoframe/include/rgb_swap.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
const int SwapChannel::_ORDER[CHANNEL_SIZE] = {2, 1, 0};
std::shared_ptr<OpContext> SwapChannel::operator()(
std::shared_ptr<OpContext> input) {
NppStatus ret = nppiSwapChannels_32f_C3IR(
input->p_frame(), input->step(), input->nppi_size(), _ORDER);
verify_npp_ret("nppiSwapChannels_32f_C3IR", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/sub.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
Sub::Sub(float subtractor) {
for (size_t i = 0; i < CHANNEL_SIZE; i++) {
_subtractors[i] = subtractor;
}
}
Sub::Sub(const std::vector<float> &subtractors) {
if (subtractors.size() != CHANNEL_SIZE) {
throw std::runtime_error("size of subtractors must be three");
}
for (size_t i = 0; i < CHANNEL_SIZE; i++) {
_subtractors[i] = subtractors[i];
}
}
std::shared_ptr<OpContext> Sub::operator()(std::shared_ptr<OpContext> input) {
NppStatus ret = nppiSubC_32f_C3IR(
_subtractors, input->p_frame(), input->step(), input->nppi_size());
verify_npp_ret("nppiSubC_32f_C3IR", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <sstream>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/utils.h"
void verify_npp_ret(const std::string& function_name, NppStatus ret) {
if (ret != NPP_SUCCESS) {
std::ostringstream ss;
ss << function_name << ", ret: " << ret;
throw std::runtime_error(ss.str());
}
}
void verify_cuda_ret(const std::string& function_name, cudaError_t ret) {
if (ret != cudaSuccess) {
std::ostringstream ss;
ss << function_name << ", ret: " << ret;
throw std::runtime_error(ss.str());
}
}
......@@ -33,6 +33,7 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model
mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client
```
if your model is bert_chinese_L-12_H-768_A-12_model, replace the 'bert_seq128_model' field in the following command with 'bert_chinese_L-12_H-768_A-12_model',replace 'bert_seq128_client' with 'bert_chinese_L-12_H-768_A-12_client'.
### Getting Dict and Sample Dataset
......
......@@ -28,6 +28,9 @@ tar -xzf bert_chinese_L-12_H-768_A-12.tar.gz
mv bert_chinese_L-12_H-768_A-12_model bert_seq128_model
mv bert_chinese_L-12_H-768_A-12_client bert_seq128_client
```
若使用bert_chinese_L-12_H-768_A-12_model模型,将下面命令中的bert_seq128_model字段替换为bert_chinese_L-12_H-768_A-12_model,bert_seq128_client字段替换为bert_chinese_L-12_H-768_A-12_client.
### 获取词典和样例数据
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import unittest
import sys
import numpy as np
import cv2
from paddle_serving_app.reader import Sequential, Resize, File2Image
import libgpupreprocess as pp
import libhwextract
class TestOperators(unittest.TestCase):
"""
test all operators, e.g. Div, Normalize
"""
def test_div(self):
height = 4
width = 5
channels = 3
value = 255.0
img = np.arange(height * width * channels).reshape(
[height, width, channels])
seq = Sequential(
[pp.Image2Gpubuffer(), pp.Div(value), pp.Gpubuffer2Image()])
result = seq(img).reshape(-1)
for i in range(0, result.size):
self.assertAlmostEqual(i / value, result[i], 5)
def test_sub(self):
height = 4
width = 5
channels = 3
img = np.arange(height * width * channels).reshape(
[height, width, channels])
# input size is an int
value = 10.0
seq = Sequential(
[pp.Image2Gpubuffer(), pp.Sub(value), pp.Gpubuffer2Image()])
result = seq(img).reshape(-1)
for i in range(0, result.size):
self.assertEqual(i - value, result[i])
# input size is a sequence
values = (9, 4, 2)
seq = Sequential(
[pp.Image2Gpubuffer(), pp.Sub(values), pp.Gpubuffer2Image()])
result = seq(img)
for i in range(0, result.shape[0]):
for j in range(0, result.shape[1]):
for k in range(0, result.shape[2]):
self.assertEqual(result[i][j][k], img[i][j][k] - values[k])
def test_normalize(self):
height = 4
width = 5
channels = 3
img = np.random.rand(height, width, channels)
mean = [5.0, 5.0, 5.0]
std = [2.0, 2.0, 2.0]
seq = Sequential([
pp.Image2Gpubuffer(), pp.Normalize(mean, std), pp.Gpubuffer2Image()
])
result = seq(img)
for i in range(0, height):
for j in range(0, width):
for k in range(0, channels):
self.assertAlmostEqual((img[i][j][k] - mean[k]) / std[k],
result[i][j][k], 5)
def test_center_crop(self):
height = 9
width = 7
channels = 3
img = np.arange(height * width * channels).reshape(
[height, width, channels])
new_size = 5
seq = Sequential([
pp.Image2Gpubuffer(), pp.CenterCrop(new_size), pp.Gpubuffer2Image()
])
result = seq(img)
self.assertEqual(result.shape[0], new_size)
self.assertEqual(result.shape[1], new_size)
self.assertEqual(result.shape[2], channels)
def test_resize(self):
height = 9
width = 5
channels = 3
img = np.arange(height * width).reshape([height, width, 1]) * np.ones(
(1, channels))
# input size is an int
for new_size in [3, 10]:
seq_gpu = Sequential([
pp.Image2Gpubuffer(), pp.Resize(new_size), pp.Gpubuffer2Image()
])
seq_paddle = Sequential([Resize(new_size)])
result_gpu = seq_gpu(img)
result_paddle = seq_paddle(img)
self.assertEqual(result_gpu.shape, result_paddle.shape)
for i in range(0, result_gpu.shape[0]):
for j in range(0, result_gpu.shape[1]):
for k in range(0, result_gpu.shape[2]):
self.assertAlmostEqual(result_gpu[i][j][k],
result_paddle[i][j][k], 5)
# input size is a sequence
for new_height, new_width in [(7, 3), (15, 10)]:
seq_gpu = Sequential([
pp.Image2Gpubuffer(), pp.Resize((new_width, new_height)),
pp.Gpubuffer2Image()
])
seq_paddle = Sequential([Resize((new_width, new_height))])
result_gpu = seq_gpu(img)
result_paddle = seq_paddle(img)
self.assertEqual(result_gpu.shape, result_paddle.shape)
for i in range(0, result_gpu.shape[0]):
for j in range(0, result_gpu.shape[1]):
for k in range(0, result_gpu.shape[2]):
self.assertAlmostEqual(result_gpu[i][j][k],
result_paddle[i][j][k], 5)
def test_resize_fixed_point(self):
new_height = 256
new_width = 256 * 4 / 3
seq = Sequential([
File2Image(), pp.Image2Gpubuffer(), pp.Resize(
(new_width, new_height), use_fixed_point=True),
pp.Gpubuffer2Image()
])
img = seq("./capture_16.bmp")
img = np.resize(img, (new_height, new_width * 3))
img_vis = np.loadtxt("./cap_resize_16.raw")
img_resize_diff = img_vis - img
self.assertEqual(np.all(img_resize_diff == 0), True)
if __name__ == '__main__':
unittest.main()
......@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from .arr2image import Arr2Image
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import cv2
import yaml
class Arr2Image(object):
"""
from numpy array image(jpeg) to cv::Mat image
"""
def __init__(self):
pass
def __call__(self, img_arr):
img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
return img
def __repr__(self):
return self.__class__.__name__ + "()"
......@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
import os
from setuptools import setup, Distribution, Extension
......@@ -23,14 +24,22 @@ from setuptools import find_packages
from setuptools import setup
from paddle_serving_app.version import serving_app_version
from pkg_resources import DistributionNotFound, get_distribution
import util
max_version, mid_version, min_version = util.python_version()
def python_version():
return [int(v) for v in platform.python_version().split(".")]
def find_package(pkgname):
try:
get_distribution(pkgname)
return True
except DistributionNotFound:
return False
max_version, mid_version, min_version = python_version()
if '${PACK}' == 'ON':
copy_lib()
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'sentencepiece<=0.1.83', 'opencv-python<=4.2.0.32', 'pillow',
'pyclipper', 'shapely'
......@@ -43,7 +52,11 @@ packages=['paddle_serving_app',
'paddle_serving_app.models',
'paddle_serving_app.reader.pddet']
package_data={}
if os.path.exists('../core/preprocess/hwvideoframe/libgpupreprocess.so'):
package_data={'paddle_serving_app': ['../core/preprocess/hwvideoframe/libgpupreprocess.so'],}
else:
package_data={}
package_dir={'paddle_serving_app':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app',
'paddle_serving_app.proto':
......@@ -55,7 +68,8 @@ package_dir={'paddle_serving_app':
'paddle_serving_app.models':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/models',
'paddle_serving_app.reader.pddet':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',}
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',
}
setup(
name='paddle-serving-app',
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from pkg_resources import DistributionNotFound, get_distribution
from grpc_tools import protoc
......
......@@ -932,7 +932,6 @@ function python_app_api_test(){
cd imagenet
case $TYPE in
CPU)
check_cmd "python test_image_reader.py"
;;
GPU)
echo "no implement for cpu type"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册