Commit 50f1bbfd authored by BohaoWu

Adjust code style.

Parent 4c2614f5
cmake_minimum_required(VERSION 3.2)
project(hw-frame-extract)
# SET(CUDA_VERSION 10.1)
#gcc version
#GCC('gcc482')
#CUDA("10.1")
set(global_cflags_str "-g -pipe -W -Wall -fPIC")
set(CMAKE_C_FLAGS ${global_cflags_str})
#C++ flags.
set(global_cxxflags_str "-g -pipe -W -Wall -fPIC -std=c++11")
set(CMAKE_CXX_FLAGS ${global_cxxflags_str})
add_subdirectory(cuda)
add_subdirectory(pybind11)
set (EXTRA_LIBS ${EXTRA_LIBS} gpu)
message(${CMAKE_CURRENT_SOURCE_DIR})
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include")
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/pybind11/include")
include_directories("/opt/compiler/cuda-10.1/include")
include_directories("/home/work/wubohao/baidu/third-party/python/include/python2.7")
file(GLOB SOURCE_FILES src/*.cpp pybind11/*.cpp )
link_directories("-L/opt/compiler/cuda-10.1/lib64 -lcudart -lnppidei_static -lnppial_static -lnpps_static -lnppc_static -lculibos")
# link_directories("/home/work/wubohao/baidu/third-party/python/lib")
#.so
add_library(gpupreprocess SHARED ${SOURCE_FILES})
target_link_libraries(gpupreprocess ${EXTRA_LIBS} cudart nppidei nppial npps nppc)
# hwvideoframe
A brief description.
## Quick Start
How to build, install, and run.
## Testing
How to run the automated tests.
## Contributing
Patch contribution workflow and quality requirements.
## Discussion
Baidu Hi discussion group: XXXX
cmake_minimum_required(VERSION 3.2)
project(gpu)
FIND_PACKAGE(CUDA ${CUDA_VERSION} REQUIRED)
SET(CUDA_TARGET_INCLUDE ${CUDA_TOOLKIT_ROOT_DIR}-${CUDA_VERSION}/targets/${CMAKE_HOST_SYSTEM_PROCESSOR}-${LOWER_SYSTEM_NAME}/include)
file(GLOB_RECURSE CURRENT_HEADERS *.h *.hpp *.cuh)
file(GLOB CURRENT_SOURCES *.cpp *.cu)
file(GLOB CUDA_LIBS /opt/compiler/cuda-10.1/lib64/*.so)
source_group("Include" FILES ${CURRENT_HEADERS})
source_group("Source" FILES ${CURRENT_SOURCES})
set(CMAKE_CUDA_FLAGS "-ccbin /opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
set(CUDA_NVCC_FLAGS "-ccbin /opt/compiler/gcc-4.8.2/bin -Xcompiler -fPIC --std=c++11")
include_directories("/opt/compiler/cuda-10.1/include")
#cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES})
cuda_add_library(gpu SHARED ${CURRENT_HEADERS} ${CURRENT_SOURCES})
target_link_libraries(gpu ${CUDA_LIBS})
# import libgpupreprocess as pp
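A minimal usage sketch of the bindings (an assumption-laden example, not part of the sources: it presumes libgpupreprocess has been built from this tree and is importable, and that Sequential from paddle_serving_app.reader is available, as in the unit tests further down):

import numpy as np
import libgpupreprocess as pp  # assumption: the .so built above is on PYTHONPATH
from paddle_serving_app.reader import Sequential

img = np.random.rand(224, 224, 3).astype(np.float32)  # HWC layout, 3 channels, float32
seq = Sequential([pp.Image2Gpubuffer(), pp.Gpubuffer2Image()])
out = seq(img)  # copied to GPU memory and back; out.shape == (224, 224, 3)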
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "./cuda_runtime.h"
#define clip(x, a, b) x >= a ? (x < b ? x : b - 1) : a;
const int INTER_RESIZE_COEF_BITS = 11;
const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS;
__global__ void resizeCudaKernel(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels) {
// 2D Index of current thread
const int dx = blockIdx.x * blockDim.x + threadIdx.x;
const int dy = blockIdx.y * blockDim.y + threadIdx.y;
if ((dx < outputWidth) && (dy < outputHeight)) {
if (inputChannels == 1) { // grayscale image
// TODO(Zelda): support grayscale
} else if (inputChannels == 3) { // RGB image
double scale_x = static_cast<double>(inputWidth) / outputWidth;
double scale_y = static_cast<double>(inputHeight) / outputHeight;
int xmax = outputWidth;
float fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
int sx = floorf(fx);
fx = fx - sx;
int isx1 = sx;
if (isx1 < 0) {
fx = 0.0;
isx1 = 0;
}
if (isx1 >= (inputWidth - 1)) {
xmax = ::min(xmax, dx);
fx = 0;
isx1 = inputWidth - 1;
}
float2 cbufx;
cbufx.x = (1.f - fx);
cbufx.y = fx;
float fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
int sy = floorf(fy);
fy = fy - sy;
int isy1 = clip(sy + 0, 0, inputHeight);
int isy2 = clip(sy + 1, 0, inputHeight);
float2 cbufy;
cbufy.x = (1.f - fy);
cbufy.y = fy;
int isx2 = isx1 + 1;
float3 d0;
float3 s11 =
make_float3(input[(isy1 * inputWidth + isx1) * inputChannels + 0],
input[(isy1 * inputWidth + isx1) * inputChannels + 1],
input[(isy1 * inputWidth + isx1) * inputChannels + 2]);
float3 s12 =
make_float3(input[(isy1 * inputWidth + isx2) * inputChannels + 0],
input[(isy1 * inputWidth + isx2) * inputChannels + 1],
input[(isy1 * inputWidth + isx2) * inputChannels + 2]);
float3 s21 =
make_float3(input[(isy2 * inputWidth + isx1) * inputChannels + 0],
input[(isy2 * inputWidth + isx1) * inputChannels + 1],
input[(isy2 * inputWidth + isx1) * inputChannels + 2]);
float3 s22 =
make_float3(input[(isy2 * inputWidth + isx2) * inputChannels + 0],
input[(isy2 * inputWidth + isx2) * inputChannels + 1],
input[(isy2 * inputWidth + isx2) * inputChannels + 2]);
float h_rst00, h_rst01;
// B
if (dx > xmax - 1) {
h_rst00 = s11.x;
h_rst01 = s21.x;
} else {
h_rst00 = s11.x * cbufx.x + s12.x * cbufx.y;
h_rst01 = s21.x * cbufx.x + s22.x * cbufx.y;
}
d0.x = h_rst00 * cbufy.x + h_rst01 * cbufy.y;
// G
if (dx > xmax - 1) {
h_rst00 = s11.y;
h_rst01 = s21.y;
} else {
h_rst00 = s11.y * cbufx.x + s12.y * cbufx.y;
h_rst01 = s21.y * cbufx.x + s22.y * cbufx.y;
}
d0.y = h_rst00 * cbufy.x + h_rst01 * cbufy.y;
// R
if (dx > xmax - 1) {
h_rst00 = s11.z;
h_rst01 = s21.z;
} else {
h_rst00 = s11.z * cbufx.x + s12.z * cbufx.y;
h_rst01 = s21.z * cbufx.x + s22.z * cbufx.y;
}
d0.z = h_rst00 * cbufy.x + h_rst01 * cbufy.y;
output[(dy * outputWidth + dx) * 3 + 0] = (d0.x); // R
output[(dy * outputWidth + dx) * 3 + 1] = (d0.y); // G
output[(dy * outputWidth + dx) * 3 + 2] = (d0.z); // B
} else {
// TODO(Zelda): support alpha channel
}
}
}
__global__ void resizeCudaKernel_fixpt(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels) {
// 2D Index of current thread
const int dx = blockIdx.x * blockDim.x + threadIdx.x;
const int dy = blockIdx.y * blockDim.y + threadIdx.y;
if ((dx < outputWidth) && (dy < outputHeight)) {
if (inputChannels == 1) { // grayscale image
// TODO(Zelda): support grayscale
} else if (inputChannels == 3) { // RGB image
double scale_x = static_cast<double>(inputWidth) / outputWidth;
double scale_y = static_cast<double>(inputHeight) / outputHeight;
int xmax = outputWidth;
float fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
int sx = floorf(fx);
fx = fx - sx;
int isx1 = sx;
if (isx1 < 0) {
fx = 0.0;
isx1 = 0;
}
if (isx1 >= (inputWidth - 1)) {
xmax = ::min(xmax, dx);
fx = 0;
isx1 = inputWidth - 1;
}
short2 cbufx;
cbufx.x = lrintf((1.f - fx) * INTER_RESIZE_COEF_SCALE);
cbufx.y = lrintf(fx * INTER_RESIZE_COEF_SCALE);
float fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
int sy = floorf(fy);
fy = fy - sy;
int isy1 = clip(sy + 0, 0, inputHeight);
int isy2 = clip(sy + 1, 0, inputHeight);
short2 cbufy;
cbufy.x = lrintf((1.f - fy) * INTER_RESIZE_COEF_SCALE);
cbufy.y = lrintf(fy * INTER_RESIZE_COEF_SCALE);
int isx2 = isx1 + 1;
uchar3 d0;
int3 s11 =
make_int3(input[(isy1 * inputWidth + isx1) * inputChannels + 0],
input[(isy1 * inputWidth + isx1) * inputChannels + 1],
input[(isy1 * inputWidth + isx1) * inputChannels + 2]);
int3 s12 =
make_int3(input[(isy1 * inputWidth + isx2) * inputChannels + 0],
input[(isy1 * inputWidth + isx2) * inputChannels + 1],
input[(isy1 * inputWidth + isx2) * inputChannels + 2]);
int3 s21 =
make_int3(input[(isy2 * inputWidth + isx1) * inputChannels + 0],
input[(isy2 * inputWidth + isx1) * inputChannels + 1],
input[(isy2 * inputWidth + isx1) * inputChannels + 2]);
int3 s22 =
make_int3(input[(isy2 * inputWidth + isx2) * inputChannels + 0],
input[(isy2 * inputWidth + isx2) * inputChannels + 1],
input[(isy2 * inputWidth + isx2) * inputChannels + 2]);
int h_rst00, h_rst01;
// B
if (dx > xmax - 1) {
h_rst00 = s11.x * INTER_RESIZE_COEF_SCALE;
h_rst01 = s21.x * INTER_RESIZE_COEF_SCALE;
} else {
h_rst00 = s11.x * cbufx.x + s12.x * cbufx.y;
h_rst01 = s21.x * cbufx.x + s22.x * cbufx.y;
}
d0.x = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) +
((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >>
2);
// G
if (dx > xmax - 1) {
h_rst00 = s11.y * INTER_RESIZE_COEF_SCALE;
h_rst01 = s21.y * INTER_RESIZE_COEF_SCALE;
} else {
h_rst00 = s11.y * cbufx.x + s12.y * cbufx.y;
h_rst01 = s21.y * cbufx.x + s22.y * cbufx.y;
}
d0.y = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) +
((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >>
2);
// R
if (dx > xmax - 1) {
h_rst00 = s11.z * INTER_RESIZE_COEF_SCALE;
h_rst01 = s21.z * INTER_RESIZE_COEF_SCALE;
} else {
h_rst00 = s11.z * cbufx.x + s12.z * cbufx.y;
h_rst01 = s21.z * cbufx.x + s22.z * cbufx.y;
}
d0.z = (unsigned char)((((cbufy.x * (h_rst00 >> 4)) >> 16) +
((cbufy.y * (h_rst01 >> 4)) >> 16) + 2) >>
2);
output[(dy * outputWidth + dx) * 3 + 0] = (d0.x); // R
output[(dy * outputWidth + dx) * 3 + 1] = (d0.y); // G
output[(dy * outputWidth + dx) * 3 + 2] = (d0.z); // B
} else {
// TODO(Zelda): support alpha channel
}
}
}
extern "C" cudaError_t resize_linear(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point) {
// Specify a reasonable block size
const dim3 block(16, 16);
// Calculate grid size to cover the whole image
const dim3 grid((outputWidth + block.x - 1) / block.x,
(outputHeight + block.y - 1) / block.y);
// Launch the size conversion kernel
if (use_fixed_point) {
resizeCudaKernel_fixpt<<<grid, block>>>(input,
output,
inputWidth,
inputHeight,
outputWidth,
outputHeight,
inputChannels);
} else {
resizeCudaKernel<<<grid, block>>>(input,
output,
inputWidth,
inputHeight,
outputWidth,
outputHeight,
inputChannels);
}
// Synchronize to check for any kernel launch errors
return cudaDeviceSynchronize();
}
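The host entry point resize_linear dispatches to the fixed-point kernel when use_fixed_point is set; the same switch is exposed to Python through pp.Resize (see the bindings below). A small sketch under the same assumptions as the sketch above:

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.random.rand(480, 640, 3).astype(np.float32)
seq = Sequential([pp.Image2Gpubuffer(),
                  pp.Resize((320, 240), use_fixed_point=True),
                  pp.Gpubuffer2Image()])
out = seq(img)  # bilinear resize on the GPU via the fixed-point kernel; out.shape == (240, 320, 3)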
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
#include <npp.h>
#include <memory>
#include "./op_context.h"
// Crops the given image at the center.
// The crop size must not be larger than the input's height or width.
class CenterCrop {
public:
explicit CenterCrop(int size) : _size(size) {}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _size;
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_CENTER_CROP_H_
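A short Python sketch of CenterCrop, mirroring test_center_crop in the unit tests (same importability assumption as the sketches above):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.random.rand(9, 7, 3).astype(np.float32)
seq = Sequential([pp.Image2Gpubuffer(), pp.CenterCrop(5), pp.Gpubuffer2Image()])
out = seq(img)  # a 5x5 crop taken from the center; out.shape == (5, 5, 3)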
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
#include <npp.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// divide every pixel by a float value
class Div {
public:
explicit Div(float value);
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _divisor;
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_DIV_H_
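A short sketch of Div, mirroring test_div (same assumptions):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)
seq = Sequential([pp.Image2Gpubuffer(), pp.Div(255.0), pp.Gpubuffer2Image()])
out = seq(img)  # every pixel divided by 255.0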
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
#include <npp.h>
#include <pybind11/numpy.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// Input operator that copies numpy data to a gpu buffer
class Image2Gpubuffer {
public:
std::shared_ptr<OpContext> operator()(pybind11::array_t<float> array);
};
// Output operator that copies gpu buffer data back to numpy
class Gpubuffer2Image {
public:
pybind11::array_t<float> operator()(std::shared_ptr<OpContext> input);
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_IMAGE_IO_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// normalize operator on gpu: subtract the mean and divide by the std per channel
class Normalize {
public:
Normalize(const std::vector<float> &mean,
const std::vector<float> &std,
bool channel_first = false);
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _mean[CHANNEL_SIZE];
Npp32f _std[CHANNEL_SIZE];
bool _channel_first; // whether the channel is dimension 0; not supported yet
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_NORMALIZE_H_
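A short sketch of Normalize, mirroring test_normalize (same assumptions; channel_first is left at its default False):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.random.rand(4, 5, 3).astype(np.float32)
seq = Sequential([pp.Image2Gpubuffer(),
                  pp.Normalize([5.0, 5.0, 5.0], [2.0, 2.0, 2.0]),
                  pp.Gpubuffer2Image()])
out = seq(img)  # (img - mean) / std applied per channel on the GPU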
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
#include <npp.h>
const size_t CHANNEL_SIZE = 3;
// The context used as input/output of all operators
// contains pointer to raw data on gpu, frame size
class OpContext {
public:
OpContext() {
_step = 0;
_length = 0;
_size = 0;
_p_frame = nullptr;
_nppi_size.height = 0;
_nppi_size.width = 0;
}
// constructor that allocates gpu memory for the image raw data
OpContext(int height, int width) {
_step = sizeof(Npp32f) * width * CHANNEL_SIZE;
_length = height * width * CHANNEL_SIZE;
_size = _step * height;
_nppi_size.height = height;
_nppi_size.width = width;
cudaMalloc(reinterpret_cast<void**>(&_p_frame), _size);
}
virtual ~OpContext() { free_memory(); }
public:
Npp32f* p_frame() const { return _p_frame; }
int step() const { return _step; }
int length() const { return _length; }
int size() const { return _size; }
NppiSize& nppi_size() { return _nppi_size; }
void free_memory() {
if (_p_frame != nullptr) {
cudaFree(_p_frame);
_p_frame = nullptr;
}
_nppi_size.height = 0;
_nppi_size.width = 0;
_step = 0;
_size = 0;
}
private:
Npp32f* _p_frame; // pointer to raw data on gpu
int _step; // number of bytes in a row
int _length; // length of _p_frame, _size = _length * sizeof(Npp32f)
int _size; // number of bytes of the image
NppiSize _nppi_size; // contains height and width
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_OP_CONTEXT_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
extern "C" cudaError_t resize_linear(const float* input,
float* output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point);
// Resize the input numpy array Image to the given size.
// only support linear interpolation
// only support RGB channels
class Resize {
public:
// size is an int: the smaller edge of the image will be matched to this number
Resize(int size,
int max_size = 214748364,
bool use_fixed_point = false,
int interpolation = 0)
: _size(size),
_max_size(max_size),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {}
// size is a sequence like (w, h): the output size will be matched to it
Resize(std::vector<int> size,
int max_size = 214748364,
bool use_fixed_point = false,
int interpolation = 0)
: _size(-1),
_max_size(max_size),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {
_target_size[0] = size[0];
_target_size[1] = size[1];
}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _size; // target of smaller edge
int _target_size[2]; // target size sequence (w, h)
int _max_size;
bool _use_fixed_point;
int _interpolation; // unused
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_H_
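Both constructors are exposed in the Python bindings; a brief sketch (same assumptions as the sketches above):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.random.rand(9, 5, 3).astype(np.float32)
# int form: the smaller edge is matched to 10, the other edge scales proportionally
out1 = Sequential([pp.Image2Gpubuffer(), pp.Resize(10), pp.Gpubuffer2Image()])(img)
# sequence form: explicit (w, h) target size
out2 = Sequential([pp.Image2Gpubuffer(), pp.Resize((3, 7)), pp.Gpubuffer2Image()])(img)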
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
extern "C" cudaError_t resize_linear(const float *input,
float *output,
const int inputWidth,
const int inputHeight,
const int outputWidth,
const int outputHeight,
const int inputChannels,
const bool use_fixed_point);
// Resize the input numpy array Image to a size multiple of factor which is
// usually required by a network
// only support linear interpolation
// only support RGB channels
class ResizeByFactor {
public:
// Resize factor: make the width and height multiples of factor.
// Default is 32.
ResizeByFactor(int factor = 32,
int max_side_len = 2400,
bool use_fixed_point = false,
int interpolation = 0)
: _factor(factor),
_max_side_len(max_side_len),
_use_fixed_point(use_fixed_point),
_interpolation(interpolation) {}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
int _factor; // resize factor
int _max_side_len;
bool _use_fixed_point;
int _interpolation; // unused
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RESIZE_BY_FACTOR_H_
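A brief sketch of ResizeByFactor (same assumptions; there is no unit test for this operator in the commit, so the call below simply follows the binding's signature):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.random.rand(300, 500, 3).astype(np.float32)
seq = Sequential([pp.Image2Gpubuffer(),
                  pp.ResizeByFactor(factor=32, max_side_len=2400),
                  pp.Gpubuffer2Image()])
out = seq(img)  # both output sides are adjusted to multiples of factor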
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
#include <npp.h>
#include <memory>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// swap channel 0 and channel 2 for every pixel
// both RGB2BGR and BGR2RGB use this operator
class SwapChannel {
public:
SwapChannel() {}
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
static const int
_ORDER[CHANNEL_SIZE]; // describes how channel values are permuted
};
class RGB2BGR : public SwapChannel {};
class BGR2RGB : public SwapChannel {};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_RGB_SWAP_H_
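A brief sketch of the channel-swap operators (same assumptions):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.random.rand(4, 5, 3).astype(np.float32)
seq = Sequential([pp.Image2Gpubuffer(), pp.RGB2BGR(), pp.Gpubuffer2Image()])
out = seq(img)  # channels 0 and 2 swapped for every pixel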
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
#include <npp.h>
#include <memory>
#include <vector>
#include "core/preprocess/hwvideoframe/include/op_context.h"
// subtract a float value (or one value per channel) from every pixel
class Sub {
public:
explicit Sub(float subtractor);
explicit Sub(const std::vector<float> &subtractors);
std::shared_ptr<OpContext> operator()(std::shared_ptr<OpContext> input);
private:
Npp32f _subtractors[CHANNEL_SIZE];
};
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_SUB_H_
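A brief sketch of both Sub constructors, mirroring test_sub (same assumptions):

import numpy as np
import libgpupreprocess as pp  # assumption: module built and importable
from paddle_serving_app.reader import Sequential

img = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)
out1 = Sequential([pp.Image2Gpubuffer(), pp.Sub(10.0), pp.Gpubuffer2Image()])(img)  # one scalar for all channels
out2 = Sequential([pp.Image2Gpubuffer(), pp.Sub((9.0, 4.0, 2.0)), pp.Gpubuffer2Image()])(img)  # one value per channel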
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
#define CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
#include <npp.h>
#include <string>
// verify return value of npp function
// throw an exception if failed
void verify_npp_ret(const std::string& function_name, NppStatus ret);
// verify return value of cuda runtime function
// throw an exception if failed
void verify_cuda_ret(const std::string& function_name, cudaError_t ret);
#endif // CORE_PREPROCESS_HWVIDEOFRAME_INCLUDE_UTILS_H_
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "core/preprocess/hwvideoframe/include/center_crop.h"
#include "core/preprocess/hwvideoframe/include/div.h"
#include "core/preprocess/hwvideoframe/include/image_io.h"
#include "core/preprocess/hwvideoframe/include/normalize.h"
#include "core/preprocess/hwvideoframe/include/resize.h"
#include "core/preprocess/hwvideoframe/include/resize_by_factor.h"
#include "core/preprocess/hwvideoframe/include/rgb_swap.h"
#include "core/preprocess/hwvideoframe/include/sub.h"
PYBIND11_MODULE(libgpupreprocess, m) {
pybind11::class_<OpContext, std::shared_ptr<OpContext>>(m, "OpContext");
pybind11::class_<Image2Gpubuffer>(m, "Image2Gpubuffer")
.def(pybind11::init<>())
.def("__call__", &Image2Gpubuffer::operator());
pybind11::class_<Gpubuffer2Image>(m, "Gpubuffer2Image")
.def(pybind11::init<>())
.def("__call__", &Gpubuffer2Image::operator());
pybind11::class_<RGB2BGR>(m, "RGB2BGR")
.def(pybind11::init<>())
.def("__call__", &RGB2BGR::operator());
pybind11::class_<BGR2RGB>(m, "BGR2RGB")
.def(pybind11::init<>())
.def("__call__", &BGR2RGB::operator());
pybind11::class_<Div>(m, "Div")
.def(pybind11::init<float>())
.def("__call__", &Div::operator());
pybind11::class_<Sub>(m, "Sub")
.def(pybind11::init<float>())
.def(pybind11::init<const std::vector<float>&>())
.def("__call__", &Sub::operator());
pybind11::class_<Normalize>(m, "Normalize")
.def(pybind11::init<const std::vector<float>&,
const std::vector<float>&,
bool>(),
pybind11::arg("mean"),
pybind11::arg("std"),
pybind11::arg("channel_first") = false)
.def("__call__", &Normalize::operator());
pybind11::class_<CenterCrop>(m, "CenterCrop")
.def(pybind11::init<int>())
.def("__call__", &CenterCrop::operator());
pybind11::class_<Resize>(m, "Resize")
.def(pybind11::init<int, int, bool>(),
pybind11::arg("size"),
pybind11::arg("max_size") = 214748364,
pybind11::arg("use_fixed_point") = false)
.def(pybind11::init<const std::vector<int>&, int, bool>(),
pybind11::arg("target_size"),
pybind11::arg("max_size") = 214748364,
pybind11::arg("use_fixed_point") = false)
.def("__call__", &Resize::operator());
pybind11::class_<ResizeByFactor>(m, "ResizeByFactor")
.def(pybind11::init<int, int, bool>(),
pybind11::arg("factor") = 32,
pybind11::arg("max_side_len") = 2400,
pybind11::arg("use_fixed_point") = false)
.def("__call__", &ResizeByFactor::operator());
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <algorithm>
#include "core/preprocess/hwvideoframe/include/center_crop.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> CenterCrop::operator()(
std::shared_ptr<OpContext> input) {
int new_width = std::min(_size, input->nppi_size().width);
int new_height = std::min(_size, input->nppi_size().height);
auto output = std::make_shared<OpContext>(new_height, new_width);
int x_start = (input->nppi_size().width - new_width) / 2;
int y_start = (input->nppi_size().height - new_height) / 2;
Npp32f* p_src = input->p_frame() +
y_start * input->nppi_size().width * CHANNEL_SIZE +
x_start * CHANNEL_SIZE;
NppStatus ret = nppiCopy_32f_C3R(p_src,
input->step(),
output->p_frame(),
output->step(),
output->nppi_size());
verify_npp_ret("nppiCopy_32f_C3R", ret);
return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/div.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
Div::Div(float value) { _divisor = value; }
std::shared_ptr<OpContext> Div::operator()(std::shared_ptr<OpContext> input) {
NppStatus ret = nppsDivC_32f_I(_divisor, input->p_frame(), input->length());
verify_npp_ret("nppsDivC_32f_I", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <pybind11/numpy.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/image_io.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> Image2Gpubuffer::operator()(
pybind11::array_t<float> input) {
pybind11::buffer_info buf = input.request();
if (buf.format != pybind11::format_descriptor<float>::format()) {
throw std::runtime_error("Incompatible format: expected a float numpy!");
}
if (buf.ndim != 3) {
throw std::runtime_error("Number of dimensions must be three");
}
if (buf.shape[2] != CHANNEL_SIZE) {
throw std::runtime_error("Number of channels must be three");
}
auto result = std::make_shared<OpContext>(buf.shape[0], buf.shape[1]);
auto ret = cudaMemcpy(result->p_frame(),
static_cast<float*>(buf.ptr),
result->size(),
cudaMemcpyHostToDevice);
verify_cuda_ret("cudaMemcpy", ret);
return result;
}
pybind11::array_t<float> Gpubuffer2Image::operator()(
std::shared_ptr<OpContext> input) {
auto result = pybind11::array_t<float>({input->nppi_size().height,
input->nppi_size().width,
static_cast<int>(CHANNEL_SIZE)});
pybind11::buffer_info buf = result.request();
auto ret = cudaMemcpy(static_cast<float*>(buf.ptr),
input->p_frame(),
input->size(),
cudaMemcpyDeviceToHost);
verify_cuda_ret("cudaMemcpy", ret);
return result;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/normalize.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
Normalize::Normalize(const std::vector<float> &mean,
const std::vector<float> &std,
bool channel_first) {
if (mean.size() != CHANNEL_SIZE) {
throw std::runtime_error("size of mean must be three");
}
if (std.size() != CHANNEL_SIZE) {
throw std::runtime_error("size of std must be three");
}
for (size_t i = 0; i < CHANNEL_SIZE; i++) {
_mean[i] = mean[i];
_std[i] = std[i];
}
_channel_first = channel_first;
}
std::shared_ptr<OpContext> Normalize::operator()(
std::shared_ptr<OpContext> input) {
NppStatus ret = nppiSubC_32f_C3IR(
_mean, input->p_frame(), input->step(), input->nppi_size());
verify_npp_ret("nppiSubC_32f_C3IR", ret);
ret = nppiDivC_32f_C3IR(
_std, input->p_frame(), input->step(), input->nppi_size());
verify_npp_ret("nppiDivC_32f_C3IR", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/preprocess/hwvideoframe/include/resize.h"
#include <math.h>
#include <sstream>
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> Resize::operator()(
std::shared_ptr<OpContext> input) {
int resized_width = 0, resized_height = 0;
if (_size == -1) {
resized_width = std::min(_target_size[0], _max_size);
resized_height = std::min(_target_size[1], _max_size);
} else {
int im_max_size =
std::max(input->nppi_size().height, input->nppi_size().width);
float percent =
static_cast<float>(_size) /
std::min(input->nppi_size().height, input->nppi_size().width);
if (round(percent * im_max_size) > _max_size) {
percent = static_cast<float>(_max_size) / static_cast<float>(im_max_size);
}
resized_width =
static_cast<int>(round(input->nppi_size().width * percent));
resized_height =
static_cast<int>(round(input->nppi_size().height * percent));
}
auto output = std::make_shared<OpContext>(resized_height, resized_width);
auto ret = resize_linear(input->p_frame(),
output->p_frame(),
input->nppi_size().width,
input->nppi_size().height,
output->nppi_size().width,
output->nppi_size().height,
CHANNEL_SIZE,
_use_fixed_point);
verify_cuda_ret("resize_linear", ret);
return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <sstream>
#include "core/preprocess/hwvideoframe/include/resize.h"
#include "core/preprocess/hwvideoframe/include/resize_by_factor.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
std::shared_ptr<OpContext> ResizeByFactor::operator()(
std::shared_ptr<OpContext> input) {
int resized_width = input->nppi_size().width,
resized_height = input->nppi_size().height;
float ratio = 1.0f;
if (std::max(resized_width, resized_height) > _max_side_len) {
if (resized_width > resized_height) {
ratio = static_cast<float>(_max_side_len) / resized_width;
} else {
ratio = static_cast<float>(_max_side_len) / resized_height;
}
}
resized_width = static_cast<int>(resized_width * ratio);
resized_height = static_cast<int>(resized_height * ratio);
if (resized_height % _factor != 0) {
if (resized_height / _factor <= 1) {
resized_height = _factor;
} else {
resized_height = (resized_height / _factor - 1) * _factor;
}
}
if (resized_width % _factor != 0) {
if (resized_width / _factor <= 1) {
resized_width = _factor;
} else {
resized_width = (resized_width / _factor - 1) * _factor;
}
}
if (resized_width <= 0 || resized_height <= 0) {
return nullptr;
}
auto output = std::make_shared<OpContext>(resized_height, resized_width);
auto ret = resize_linear(input->p_frame(),
output->p_frame(),
input->nppi_size().width,
input->nppi_size().height,
output->nppi_size().width,
output->nppi_size().height,
CHANNEL_SIZE,
_use_fixed_point);
verify_cuda_ret("resize_linear", ret);
return output;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include "core/preprocess/hwvideoframe/include/utils.h"
#include "core/preprocess/hwvideoframe/src/rgb_swap.h"
const int SwapChannel::_ORDER[CHANNEL_SIZE] = {2, 1, 0};
std::shared_ptr<OpContext> SwapChannel::operator()(
std::shared_ptr<OpContext> input) {
NppStatus ret = nppiSwapChannels_32f_C3IR(
input->p_frame(), input->step(), input->nppi_size(), _ORDER);
verify_npp_ret("nppiSwapChannels_32f_C3IR", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/include/sub.h"
#include "core/preprocess/hwvideoframe/include/utils.h"
Sub::Sub(float subtractor) {
for (size_t i = 0; i < CHANNEL_SIZE; i++) {
_subtractors[i] = subtractor;
}
}
Sub::Sub(const std::vector<float> &subtractors) {
if (subtractors.size() != CHANNEL_SIZE) {
throw std::runtime_error("size of subtractors must be three");
}
for (size_t i = 0; i < CHANNEL_SIZE; i++) {
_subtractors[i] = subtractors[i];
}
}
std::shared_ptr<OpContext> Sub::operator()(std::shared_ptr<OpContext> input) {
NppStatus ret = nppiSubC_32f_C3IR(
_subtractors, input->p_frame(), input->step(), input->nppi_size());
verify_npp_ret("nppiSubC_32f_C3IR", ret);
return input;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <npp.h>
#include <sstream>
#include <stdexcept>
#include "core/preprocess/hwvideoframe/src/utils.h"
void verify_npp_ret(const std::string& function_name, NppStatus ret) {
if (ret != NPP_SUCCESS) {
std::ostringstream ss;
ss << function_name << ", ret: " << ret;
throw std::runtime_error(ss.str());
}
}
void verify_cuda_ret(const std::string& function_name, cudaError_t ret) {
if (ret != cudaSuccess) {
std::ostringstream ss;
ss << function_name << ", ret: " << ret;
throw std::runtime_error(ss.str());
}
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import unittest
import sys
import numpy as np
from paddle_serving_app.reader import Sequential, Resize, File2Image
import libgpupreprocess as pp
class TestOperators(unittest.TestCase):
    """
    test all operators, e.g. Div, Normalize
    """

    def test_div(self):
        height = 4
        width = 5
        channels = 3
        value = 255.0
        img = np.arange(height * width * channels).reshape(
            [height, width, channels])
        seq = Sequential(
            [pp.Image2Gpubuffer(), pp.Div(value), pp.Gpubuffer2Image()])
        result = seq(img).reshape(-1)
        for i in range(0, result.size):
            self.assertAlmostEqual(i / value, result[i], 5)

    def test_sub(self):
        height = 4
        width = 5
        channels = 3
        img = np.arange(height * width * channels).reshape(
            [height, width, channels])
        # input size is an int
        value = 10.0
        seq = Sequential(
            [pp.Image2Gpubuffer(), pp.Sub(value), pp.Gpubuffer2Image()])
        result = seq(img).reshape(-1)
        for i in range(0, result.size):
            self.assertEqual(i - value, result[i])
        # input size is a sequence
        values = (9, 4, 2)
        seq = Sequential(
            [pp.Image2Gpubuffer(), pp.Sub(values), pp.Gpubuffer2Image()])
        result = seq(img)
        for i in range(0, result.shape[0]):
            for j in range(0, result.shape[1]):
                for k in range(0, result.shape[2]):
                    self.assertEqual(result[i][j][k], img[i][j][k] - values[k])

    def test_normalize(self):
        height = 4
        width = 5
        channels = 3
        img = np.random.rand(height, width, channels)
        mean = [5.0, 5.0, 5.0]
        std = [2.0, 2.0, 2.0]
        seq = Sequential([
            pp.Image2Gpubuffer(), pp.Normalize(mean, std), pp.Gpubuffer2Image()
        ])
        result = seq(img)
        for i in range(0, height):
            for j in range(0, width):
                for k in range(0, channels):
                    self.assertAlmostEqual((img[i][j][k] - mean[k]) / std[k],
                                           result[i][j][k], 5)

    def test_center_crop(self):
        height = 9
        width = 7
        channels = 3
        img = np.arange(height * width * channels).reshape(
            [height, width, channels])
        new_size = 5
        seq = Sequential([
            pp.Image2Gpubuffer(), pp.CenterCrop(new_size), pp.Gpubuffer2Image()
        ])
        result = seq(img)
        self.assertEqual(result.shape[0], new_size)
        self.assertEqual(result.shape[1], new_size)
        self.assertEqual(result.shape[2], channels)

    def test_resize(self):
        height = 9
        width = 5
        channels = 3
        img = np.arange(height * width).reshape([height, width, 1]) * np.ones(
            (1, channels))
        # input size is an int
        for new_size in [3, 10]:
            seq_gpu = Sequential([
                pp.Image2Gpubuffer(), pp.Resize(new_size), pp.Gpubuffer2Image()
            ])
            seq_paddle = Sequential([Resize(new_size)])
            result_gpu = seq_gpu(img)
            result_paddle = seq_paddle(img)
            self.assertEqual(result_gpu.shape, result_paddle.shape)
            for i in range(0, result_gpu.shape[0]):
                for j in range(0, result_gpu.shape[1]):
                    for k in range(0, result_gpu.shape[2]):
                        self.assertAlmostEqual(result_gpu[i][j][k],
                                               result_paddle[i][j][k], 5)
        # input size is a sequence
        for new_height, new_width in [(7, 3), (15, 10)]:
            seq_gpu = Sequential([
                pp.Image2Gpubuffer(), pp.Resize((new_width, new_height)),
                pp.Gpubuffer2Image()
            ])
            seq_paddle = Sequential([Resize((new_width, new_height))])
            result_gpu = seq_gpu(img)
            result_paddle = seq_paddle(img)
            self.assertEqual(result_gpu.shape, result_paddle.shape)
            for i in range(0, result_gpu.shape[0]):
                for j in range(0, result_gpu.shape[1]):
                    for k in range(0, result_gpu.shape[2]):
                        self.assertAlmostEqual(result_gpu[i][j][k],
                                               result_paddle[i][j][k], 5)

    def test_resize_fixed_point(self):
        new_height = 256
        new_width = 256 * 4 // 3
        seq = Sequential([
            File2Image(), pp.Image2Gpubuffer(), pp.Resize(
                (new_width, new_height), use_fixed_point=True),
            pp.Gpubuffer2Image()
        ])
        img = seq("./capture_16.bmp")
        img = np.resize(img, (new_height, new_width * 3))
        img_vis = np.loadtxt("./cap_resize_16.raw")
        img_resize_diff = img_vis - img
        self.assertEqual(np.all(img_resize_diff == 0), True)


if __name__ == '__main__':
    unittest.main()