Commit 0ade1bc5 authored by HappyAngel, committed by Yuan Shuai

[LITE][ARM]add cv image process (#2402)

* add cv image process

* fix arm linux build error

* add LITE_WITH_CV define to make cv, test=develop

* fix cv format, and add description in utils/cv

* delete some meaningless comments, test=develop

* set LITE_WITH_CV=OFF in build.sh, test=develop

* delete cv_enum.h in utils/cv, move the contents of cv_enum.h into paddle_image_preprocess.h, test=develop

* redefine paddle_image_preprocess.h according to review comments, test=develop

* add detailed note of flipParam, test=develop

* fix format in paddle_image_preprocess.h, test=develop

* fix error when building x86, test=develop

* LITE_WITH_X86 does not contain LITE_WITH_CV
Parent 26470600
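Below is a minimal usage sketch of the preprocessing pipeline this commit adds (color convert, resize, rotate, flip, image-to-tensor). It is an illustration only: the 1920x1080 input size, the 224x224 output size, the buffer management, and the mean/scale values are placeholder assumptions, not part of the patch; only ImagePreprocess, TransParam, and the enum values come from lite/utils/cv/paddle_image_preprocess.h shown further down.

#include "lite/utils/cv/paddle_image_preprocess.h"

using namespace paddle::lite::utils::cv;

// Sketch: NV12 frame -> BGR -> 224x224 -> rotate 90 -> flip X -> NCHW tensor.
// dst_tensor is assumed to be pre-sized by the caller to {1, 3, 224, 224}.
void preprocess_frame(const uint8_t* nv12_buf, Tensor* dst_tensor) {
  TransParam tp;
  tp.iw = 1920;          // input width (placeholder)
  tp.ih = 1080;          // input height (placeholder)
  tp.ow = 224;           // output width
  tp.oh = 224;           // output height
  tp.rotate_param = 90;  // used by the no-argument imageRotate overload
  tp.flip_param = X;     // used by the no-argument imageFlip overload
  ImagePreprocess preprocess(NV12, BGR, tp);

  // caller-owned scratch buffers, 3 bytes per pixel for BGR
  uint8_t* bgr = new uint8_t[3 * tp.iw * tp.ih];
  uint8_t* bgr_small = new uint8_t[3 * tp.ow * tp.oh];
  uint8_t* bgr_trans = new uint8_t[3 * tp.ow * tp.oh];

  preprocess.imageCovert(nv12_buf, bgr);         // NV12 -> BGR
  preprocess.imageResize(bgr, bgr_small);        // 1920x1080 -> 224x224 (bilinear)
  preprocess.imageRotate(bgr_small, bgr_trans);  // rotate by tp.rotate_param
  preprocess.imageFlip(bgr_trans, bgr_small);    // flip along tp.flip_param

  float means[3] = {0.f, 0.f, 0.f};                     // placeholder means
  float scales[3] = {1.f / 255, 1.f / 255, 1.f / 255};  // placeholder scales
  preprocess.image2Tensor(bgr_small, dst_tensor, LayoutType::kNCHW, means, scales);

  delete[] bgr;
  delete[] bgr_small;
  delete[] bgr_trans;
}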
......@@ -72,6 +72,9 @@ lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF)
# publish options
lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kernels and operators" OFF)
lite_option(LITE_BUILD_TAILOR "Enable tailoring library according to model" OFF)
# cv build options
lite_option(LITE_WITH_CV "Enable build cv image in lite" OFF IF NOT LITE_WITH_ARM)
# TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter.
if(ANDROID OR IOS OR ARMLINUX)
......@@ -181,7 +184,7 @@ include(external/xxhash) # download install xxhash needed for x86 jit
include(cudnn)
include(configure) # add paddle env configuration
if(LITE_WITH_CUDA)
include(cuda)
endif()
......
......@@ -117,8 +117,12 @@ endif()
if (LITE_WITH_ARM)
add_definitions("-DLITE_WITH_ARM")
if (LITE_WITH_CV)
add_definitions("-DLITE_WITH_CV")
endif()
endif()
if (WITH_ARM_DOTPROD)
add_definitions("-DWITH_ARM_DOTPROD")
endif()
......
......@@ -43,6 +43,11 @@ function (lite_deps TARGET)
foreach(var ${lite_deps_ARM_DEPS})
set(deps ${deps} ${var})
endforeach(var)
if(LITE_WITH_CV)
foreach(var ${lite_cv_deps})
set(deps ${deps} ${var})
endforeach(var)
endif()
endif()
if(LITE_WITH_PROFILE)
......@@ -341,7 +346,7 @@ function(add_kernel TARGET device level)
file(APPEND ${kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
endforeach()
nv_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS})
return()
endif()
# the source list will collect for paddle_use_kernel.h code generation.
......
......@@ -9,6 +9,7 @@ message(STATUS "LITE_WITH_NPU:\t${LITE_WITH_NPU}")
message(STATUS "LITE_WITH_XPU:\t${LITE_WITH_XPU}")
message(STATUS "LITE_WITH_FPGA:\t${LITE_WITH_FPGA}")
message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}")
message(STATUS "LITE_WITH_CV:\t${LITE_WITH_CV}")
set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
set(LITE_ON_MOBILE ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK})
......@@ -129,6 +130,7 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
#COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/model_optimize_tool" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
)
if(NOT IOS)
#add_dependencies(publish_inference_cxx_lib model_optimize_tool)
......@@ -136,10 +138,10 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
add_dependencies(publish_inference_cxx_lib bundle_full_api)
add_dependencies(publish_inference_cxx_lib bundle_light_api)
add_dependencies(publish_inference_cxx_lib test_model_bin)
if (ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux")
add_dependencies(publish_inference_cxx_lib paddle_full_api_shared)
add_dependencies(publish_inference paddle_light_api_shared)
add_custom_command(TARGET publish_inference_cxx_lib
COMMAND cp ${CMAKE_BINARY_DIR}/lite/api/*.so ${INFER_LITE_PUBLISH_ROOT}/cxx/lib)
endif()
add_dependencies(publish_inference publish_inference_cxx_lib)
......@@ -155,6 +157,7 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/include"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/include"
COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/lib"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
)
add_dependencies(tiny_publish_lib bundle_light_api)
add_dependencies(publish_inference tiny_publish_lib)
......@@ -166,6 +169,7 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/libpaddle_light_api_shared.so" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
)
add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared)
add_dependencies(publish_inference tiny_publish_cxx_lib)
......
add_subdirectory(kernels)
add_subdirectory(math)
add_subdirectory(cv)
if(LITE_WITH_CV AND (NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND LITE_WITH_ARM)
lite_cc_test(image_convert_test SRCS image_convert_test.cc DEPS paddle_cv_arm paddle_api_light ${lite_cv_deps} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
This diff is collapsed.
This diff is collapsed.
......@@ -19,6 +19,7 @@ BUILD_PYTHON=OFF
BUILD_DIR=$(pwd)
OPTMODEL_DIR=""
BUILD_TAILOR=OFF
BUILD_CV=OFF
readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
......@@ -96,6 +97,7 @@ function make_tiny_publish_so {
-DLITE_ON_TINY_PUBLISH=ON \
-DANDROID_STL_TYPE=$android_stl \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_CV=$BUILD_CV \
-DLITE_BUILD_TAILOR=$BUILD_TAILOR \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang}
......@@ -122,7 +124,7 @@ function make_full_publish_so {
fi
mkdir -p $build_directory
cd $build_directory
if [ ${os} == "armlinux" ]; then
BUILD_JAVA=OFF
fi
......@@ -137,6 +139,7 @@ function make_full_publish_so {
-DLITE_SHUTDOWN_LOG=ON \
-DANDROID_STL_TYPE=$android_stl \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_CV=$BUILD_CV \
-DLITE_BUILD_TAILOR=$BUILD_TAILOR \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang}
......@@ -166,6 +169,7 @@ function make_all_tests {
${CMAKE_COMMON_OPTIONS} \
-DWITH_TESTING=ON \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_CV=$BUILD_CV \
-DARM_TARGET_OS=${os} -DARM_TARGET_ARCH_ABI=${abi} -DARM_TARGET_LANG=${lang}
make lite_compile_deps -j$NUM_PROC
......@@ -201,6 +205,7 @@ function make_ios {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DARM_TARGET_ARCH_ABI=$abi \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_CV=$BUILD_CV \
-DARM_TARGET_OS=$os
make -j4 publish_inference
......@@ -362,11 +367,11 @@ function main {
shift
;;
tiny_publish)
make_tiny_publish_so $ARM_OS $ARM_ABI $ARM_LANG $ANDROID_STL
shift
;;
full_publish)
make_full_publish_so $ARM_OS $ARM_ABI $ARM_LANG $ANDROID_STL
shift
;;
test)
......@@ -382,7 +387,7 @@ function main {
shift
;;
cuda)
make_cuda
shift
;;
x86)
......
......@@ -24,3 +24,5 @@ if(LITE_ON_TINY_PUBLISH OR LITE_ON_MODEL_OPTIMIZE_TOOL)
else()
lite_cc_library(utils SRCS string.cc DEPS ${utils_DEPS} any)
endif()
add_subdirectory(cv)
if(LITE_WITH_CV AND (NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND LITE_WITH_ARM)
set(lite_cv_deps)
lite_cc_library(paddle_cv_arm SRCS
image_convert.cc
paddle_image_preprocess.cc
image2tensor.cc
image_flip.cc
image_rotate.cc
image_resize.cc
DEPS ${lite_cv_deps} paddle_api_light)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/utils/cv/image2tensor.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
void bgr_to_tensor_chw(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales);
void bgra_to_tensor_chw(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales);
void bgr_to_tensor_hwc(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales);
void bgra_to_tensor_hwc(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales);
/*
* change image data to tensor data
* supported image formats are BGR(RGB) and BGRA(RGBA); supported data
* layouts are NHWC and NCHW
* param src: input image data
* param dstTensor: output tensor data
* param srcFormat: input image format, support BGR(RGB) and BGRA(RGBA)
* param srcw: input image width
* param srch: input image height
* param layout: output tensor layout, support NHWC and NCHW
* param means: means of image
* param scales: scales of image
*/
void Image2Tensor::choose(const uint8_t* src,
Tensor* dst,
ImageFormat srcFormat,
LayoutType layout,
int srcw,
int srch,
float* means,
float* scales) {
float* output = dst->mutable_data<float>();
if (layout == LayoutType::kNCHW && (srcFormat == BGR || srcFormat == RGB)) {
impl_ = bgr_to_tensor_chw;
} else if (layout == LayoutType::kNHWC &&
(srcFormat == BGR || srcFormat == RGB)) {
impl_ = bgr_to_tensor_hwc;
} else if (layout == LayoutType::kNCHW &&
(srcFormat == BGRA || srcFormat == RGBA)) {
impl_ = bgra_to_tensor_chw;
} else if (layout == LayoutType::kNHWC &&
(srcFormat == BGRA || srcFormat == RGBA)) {
impl_ = bgra_to_tensor_hwc;
} else {
printf("this layout: %d or image format: %d not support \n",
static_cast<int>(layout),
srcFormat);
return;
}
impl_(src, output, srcw, srch, means, scales);
}
void bgr_to_tensor_chw(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales) {
int size = width * height;
float b_means = means[0];
float g_means = means[1];
float r_means = means[2];
float b_scales = scales[0];
float g_scales = scales[1];
float r_scales = scales[2];
float* ptr_b = output;
float* ptr_g = ptr_b + size;
float* ptr_r = ptr_g + size;
int dim8 = width >> 3;
int remain = width % 8;
float32x4_t vbmean = vdupq_n_f32(b_means);
float32x4_t vgmean = vdupq_n_f32(g_means);
float32x4_t vrmean = vdupq_n_f32(r_means);
float32x4_t vbscale = vdupq_n_f32(b_scales);
float32x4_t vgscale = vdupq_n_f32(g_scales);
float32x4_t vrscale = vdupq_n_f32(r_scales);
#pragma omp parallel for
for (int i = 0; i < height; i += 1) {
const uint8_t* din_ptr = src + i * 3 * width;
float* ptr_b_h = ptr_b + i * width;
float* ptr_g_h = ptr_g + i * width;
float* ptr_r_h = ptr_r + i * width;
int cnt = dim8;
if (cnt > 0) {
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[inptr0]] \n"
"prfm pldl1keep, [%[inptr0], #64] \n"
"prfm pldl1keep, [%[inptr0], #128] \n"
"prfm pldl1keep, [%[inptr0], #192] \n"
"1: \n"
"ld3 {v0.8b, v1.8b, v2.8b}, [%[inptr0]], #24 \n" // d8 = y0y3y6y9..
// d9 = y1y4y7..."
// 8->16
"ushll v3.8h, v0.8b, #0 \n"
"ushll v4.8h, v1.8b, #0 \n"
"ushll v5.8h, v2.8b, #0 \n"
// 16->32
"ushll v6.4s, v3.4h, #0 \n"
"ushll2 v7.4s, v3.8h, #0 \n"
"ushll v8.4s, v4.4h, #0 \n"
"ushll2 v9.4s, v4.8h, #0 \n"
"ushll v10.4s, v5.4h, #0 \n"
"ushll2 v11.4s, v5.8h, #0 \n"
// int32->fp32
"ucvtf v12.4s, v6.4s \n"
"ucvtf v13.4s, v7.4s \n"
"ucvtf v14.4s, v8.4s \n"
"ucvtf v15.4s, v9.4s \n"
"ucvtf v16.4s, v10.4s \n"
"ucvtf v17.4s, v11.4s \n"
// sub -mean
"fsub v12.4s, v12.4s, %w[vbmean].4s \n"
"fsub v13.4s, v13.4s, %w[vbmean].4s \n"
"fsub v14.4s, v14.4s, %w[vgmean].4s \n"
"fsub v15.4s, v15.4s, %w[vgmean].4s \n"
"fsub v16.4s, v16.4s, %w[vrmean].4s \n"
"fsub v17.4s, v17.4s, %w[vrmean].4s \n"
// mul * scale
"fmul v6.4s, v12.4s, %w[vbscale].4s \n"
"fmul v7.4s, v13.4s, %w[vbscale].4s \n"
"fmul v8.4s, v14.4s, %w[vgscale].4s \n"
"fmul v9.4s, v15.4s, %w[vgscale].4s \n"
"fmul v10.4s, v16.4s, %w[vrscale].4s \n"
"fmul v11.4s, v17.4s, %w[vrscale].4s \n"
// store
"st1 {v6.4s}, [%[outr0]], #16 \n"
"st1 {v8.4s}, [%[outr1]], #16 \n"
"st1 {v10.4s}, [%[outr2]], #16 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"st1 {v7.4s}, [%[outr0]], #16 \n"
"st1 {v9.4s}, [%[outr1]], #16 \n"
"st1 {v11.4s}, [%[outr2]], #16 \n"
"bne 1b \n"
: [inptr0] "+r"(din_ptr),
[outr0] "+r"(ptr_b_h),
[outr1] "+r"(ptr_g_h),
[outr2] "+r"(ptr_r_h),
[cnt] "+r"(cnt)
: [vbmean] "w"(vbmean),
[vgmean] "w"(vgmean),
[vrmean] "w"(vrmean),
[vbscale] "w"(vbscale),
[vgscale] "w"(vgscale),
[vrscale] "w"(vrscale)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(
"pld [%[inptr0]] @ preload a, 64byte\n"
"pld [%[inptr0], #64] @ preload a, 64byte\n"
"pld [%[inptr0], #128] @ preload a, 64byte\n"
"pld [%[inptr0], #192] @ preload a, 64byte\n"
"1: \n"
"vld3.8 {d12, d13, d14}, [%[inptr0]]! \n"
// 8->16
"vmovl.u8 q8, d12 \n"
"vmovl.u8 q9, d13 \n"
"vmovl.u8 q10, d14 \n"
// 16->32
"vmovl.u16 q11, d16 \n"
"vmovl.u16 q12, d17 \n"
"vmovl.u16 q13, d18 \n"
"vmovl.u16 q14, d19 \n"
"vmovl.u16 q15, d20 \n"
"vmovl.u16 q6, d21 \n"
// int32->fp32
"vcvt.f32.u32 q7, q11 \n"
"vcvt.f32.u32 q8, q12 \n"
"vcvt.f32.u32 q9, q13 \n"
"vcvt.f32.u32 q10, q14 \n"
"vcvt.f32.u32 q11, q15 \n"
"vcvt.f32.u32 q12, q6 \n"
// sub -mean
"vsub.f32 q7, q7, %q[vbmean] \n"
"vsub.f32 q8, q8, %q[vbmean] \n"
"vsub.f32 q9, q9, %q[vgmean] \n"
"vsub.f32 q10, q10, %q[vgmean] \n"
"vsub.f32 q11, q11, %q[vrmean] \n"
"vsub.f32 q12, q12, %q[vrmean] \n"
// mul *scale
"vmul.f32 q13, q7, %q[vbscale] \n"
"vmul.f32 q14, q8, %q[vbscale] \n"
"vmul.f32 q15, q9, %q[vgscale] \n"
"vmul.f32 q6, q10, %q[vgscale] \n"
"vmul.f32 q7, q11, %q[vrscale] \n"
"vmul.f32 q8, q12, %q[vrscale] \n"
// store
"vst1.32 {d26 - d27}, [%[outr0]]! \n"
"vst1.32 {d30 - d31}, [%[outr1]]! \n"
"vst1.32 {d14 - d15}, [%[outr2]]! \n"
"subs %[cnt], #1 \n"
"vst1.32 {d28 - d29}, [%[outr0]]! \n"
"vst1.32 {d12 - d13}, [%[outr1]]! \n"
"vst1.32 {d16 - d17}, [%[outr2]]! \n"
"bne 1b"
: [inptr0] "+r"(din_ptr),
[outr0] "+r"(ptr_b_h),
[outr1] "+r"(ptr_g_h),
[outr2] "+r"(ptr_r_h),
[cnt] "+r"(cnt)
: [vbmean] "w"(vbmean),
[vgmean] "w"(vgmean),
[vrmean] "w"(vrmean),
[vbscale] "w"(vbscale),
[vgscale] "w"(vgscale),
[vrscale] "w"(vrscale)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
for (int j = 0; j < remain; j++) {
*ptr_b_h++ = (*din_ptr - b_means) * b_scales;
din_ptr++;
*ptr_g_h++ = (*din_ptr - g_means) * g_scales;
din_ptr++;
*ptr_r_h++ = (*din_ptr - r_means) * r_scales;
din_ptr++;
}
}
}
void bgra_to_tensor_chw(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales) {
int size = width * height;
float b_means = means[0];
float g_means = means[1];
float r_means = means[2];
float b_scales = scales[0];
float g_scales = scales[1];
float r_scales = scales[2];
float* ptr_b = output;
float* ptr_g = ptr_b + size;
float* ptr_r = ptr_g + size;
int dim8 = width >> 3;
int remain = width % 8;
float32x4_t vbmean = vdupq_n_f32(b_means);
float32x4_t vgmean = vdupq_n_f32(g_means);
float32x4_t vrmean = vdupq_n_f32(r_means);
float32x4_t vbscale = vdupq_n_f32(b_scales);
float32x4_t vgscale = vdupq_n_f32(g_scales);
float32x4_t vrscale = vdupq_n_f32(r_scales);
#pragma omp parallel for
for (int i = 0; i < height; i += 1) {
const uint8_t* din_ptr = src + i * 4 * width;
float* ptr_b_h = ptr_b + i * width;
float* ptr_g_h = ptr_g + i * width;
float* ptr_r_h = ptr_r + i * width;
for (int j = 0; j < dim8; j++) {
uint8x8x4_t v_bgr = vld4_u8(din_ptr);
uint16x8_t vb_16 = vmovl_u8(v_bgr.val[0]);
uint16x8_t vg_16 = vmovl_u8(v_bgr.val[1]);
uint16x8_t vr_16 = vmovl_u8(v_bgr.val[2]);
uint32x4_t vb_low_32 = vmovl_u16(vget_low_u16(vb_16));
uint32x4_t vg_low_32 = vmovl_u16(vget_low_u16(vg_16));
uint32x4_t vr_low_32 = vmovl_u16(vget_low_u16(vr_16));
uint32x4_t vb_high_32 = vmovl_u16(vget_high_u16(vb_16));
uint32x4_t vg_high_32 = vmovl_u16(vget_high_u16(vg_16));
uint32x4_t vr_high_32 = vmovl_u16(vget_high_u16(vr_16));
float32x4_t vb_low_f32 = vcvtq_f32_u32(vb_low_32);
float32x4_t vr_low_f32 = vcvtq_f32_u32(vr_low_32);
float32x4_t vg_low_f32 = vcvtq_f32_u32(vg_low_32);
float32x4_t vb_high_f32 = vcvtq_f32_u32(vb_high_32);
float32x4_t vg_high_f32 = vcvtq_f32_u32(vg_high_32);
float32x4_t vr_high_f32 = vcvtq_f32_u32(vr_high_32);
vb_low_f32 = vsubq_f32(vb_low_f32, vbmean);
vg_low_f32 = vsubq_f32(vg_low_f32, vgmean);
vr_low_f32 = vsubq_f32(vr_low_f32, vrmean);
vb_high_f32 = vsubq_f32(vb_high_f32, vbmean);
vg_high_f32 = vsubq_f32(vg_high_f32, vgmean);
vr_high_f32 = vsubq_f32(vr_high_f32, vrmean);
vb_low_f32 = vmulq_f32(vb_low_f32, vbscale);
vg_low_f32 = vmulq_f32(vg_low_f32, vgscale);
vr_low_f32 = vmulq_f32(vr_low_f32, vrscale);
vb_high_f32 = vmulq_f32(vb_high_f32, vbscale);
vg_high_f32 = vmulq_f32(vg_high_f32, vgscale);
vr_high_f32 = vmulq_f32(vr_high_f32, vrscale);
vst1q_f32(ptr_b_h, vb_low_f32);
vst1q_f32(ptr_g_h, vg_low_f32);
vst1q_f32(ptr_r_h, vr_low_f32);
din_ptr += 32;
vst1q_f32(ptr_b_h + 4, vb_high_f32);
vst1q_f32(ptr_g_h + 4, vg_high_f32);
vst1q_f32(ptr_r_h + 4, vr_high_f32);
ptr_b_h += 8;
ptr_g_h += 8;
ptr_r_h += 8;
}
for (int j = 0; j < remain; j++) {
*ptr_b_h++ = (*din_ptr - b_means) * b_scales;
din_ptr++;
*ptr_g_h++ = (*din_ptr - g_means) * g_scales;
din_ptr++;
*ptr_r_h++ = (*din_ptr - r_means) * r_scales;
din_ptr++;
din_ptr++; // a
}
}
}
void bgr_to_tensor_hwc(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales) {
int size = width * height;
float b_means = means[0];
float g_means = means[1];
float r_means = means[2];
float b_scales = scales[0];
float g_scales = scales[1];
float r_scales = scales[2];
float* dout = output;
int dim8 = width >> 3;
int remain = width % 8;
float32x4_t vbmean = vdupq_n_f32(b_means);
float32x4_t vgmean = vdupq_n_f32(g_means);
float32x4_t vrmean = vdupq_n_f32(r_means);
float32x4_t vbscale = vdupq_n_f32(b_scales);
float32x4_t vgscale = vdupq_n_f32(g_scales);
float32x4_t vrscale = vdupq_n_f32(r_scales);
#pragma omp parallel for
for (int i = 0; i < height; i += 1) {
const uint8_t* din_ptr = src + i * 3 * width;
float* dout_ptr = dout + i * 3 * width;
for (int j = 0; j < dim8; j++) {
uint8x8x3_t v_bgr = vld3_u8(din_ptr);
uint16x8_t vb_16 = vmovl_u8(v_bgr.val[0]);
uint16x8_t vg_16 = vmovl_u8(v_bgr.val[1]);
uint16x8_t vr_16 = vmovl_u8(v_bgr.val[2]);
uint32x4_t vb_low_32 = vmovl_u16(vget_low_u16(vb_16));
uint32x4_t vg_low_32 = vmovl_u16(vget_low_u16(vg_16));
uint32x4_t vr_low_32 = vmovl_u16(vget_low_u16(vr_16));
uint32x4_t vb_high_32 = vmovl_u16(vget_high_u16(vb_16));
uint32x4_t vg_high_32 = vmovl_u16(vget_high_u16(vg_16));
uint32x4_t vr_high_32 = vmovl_u16(vget_high_u16(vr_16));
float32x4_t vb_low_f32 = vcvtq_f32_u32(vb_low_32);
float32x4_t vr_low_f32 = vcvtq_f32_u32(vr_low_32);
float32x4_t vg_low_f32 = vcvtq_f32_u32(vg_low_32);
float32x4_t vb_high_f32 = vcvtq_f32_u32(vb_high_32);
float32x4_t vg_high_f32 = vcvtq_f32_u32(vg_high_32);
float32x4_t vr_high_f32 = vcvtq_f32_u32(vr_high_32);
vb_low_f32 = vsubq_f32(vb_low_f32, vbmean);
vg_low_f32 = vsubq_f32(vg_low_f32, vgmean);
vr_low_f32 = vsubq_f32(vr_low_f32, vrmean);
vb_high_f32 = vsubq_f32(vb_high_f32, vbmean);
vg_high_f32 = vsubq_f32(vg_high_f32, vgmean);
vr_high_f32 = vsubq_f32(vr_high_f32, vrmean);
vb_low_f32 = vmulq_f32(vb_low_f32, vbscale);
vg_low_f32 = vmulq_f32(vg_low_f32, vgscale);
vr_low_f32 = vmulq_f32(vr_low_f32, vrscale);
vb_high_f32 = vmulq_f32(vb_high_f32, vbscale);
vg_high_f32 = vmulq_f32(vg_high_f32, vgscale);
vr_high_f32 = vmulq_f32(vr_high_f32, vrscale);
float32x4x3_t val;
val.val[0] = vb_low_f32;
val.val[1] = vg_low_f32;
val.val[2] = vr_low_f32;
vst3q_f32(dout_ptr, val);
din_ptr += 24;
dout_ptr += 12;
val.val[0] = vb_high_f32;
val.val[1] = vg_high_f32;
val.val[2] = vr_high_f32;
vst3q_f32(dout_ptr, val);
dout_ptr += 12;
}
for (int j = 0; j < remain; j++) {
*dout_ptr++ = (*din_ptr - b_means) * b_scales;
din_ptr++;
*dout_ptr++ = (*din_ptr - g_means) * g_scales;
din_ptr++;
*dout_ptr++ = (*din_ptr - r_means) * r_scales;
din_ptr++;
}
}
}
void bgra_to_tensor_hwc(const uint8_t* src,
float* output,
int width,
int height,
float* means,
float* scales) {
int size = width * height;
float b_means = means[0];
float g_means = means[1];
float r_means = means[2];
float b_scales = scales[0];
float g_scales = scales[1];
float r_scales = scales[2];
float* dout = output;
int dim8 = width >> 3;
int remain = width % 8;
float32x4_t vbmean = vdupq_n_f32(b_means);
float32x4_t vgmean = vdupq_n_f32(g_means);
float32x4_t vrmean = vdupq_n_f32(r_means);
float32x4_t vbscale = vdupq_n_f32(b_scales);
float32x4_t vgscale = vdupq_n_f32(g_scales);
float32x4_t vrscale = vdupq_n_f32(r_scales);
#pragma omp parallel for
for (int i = 0; i < height; i += 1) {
const uint8_t* din_ptr = src + i * 4 * width;
float* dout_ptr = dout + i * 3 * width;
for (int j = 0; j < dim8; j++) {
uint8x8x4_t v_bgr = vld4_u8(din_ptr);
uint16x8_t vb_16 = vmovl_u8(v_bgr.val[0]);
uint16x8_t vg_16 = vmovl_u8(v_bgr.val[1]);
uint16x8_t vr_16 = vmovl_u8(v_bgr.val[2]);
// uint16x8_t va_16 = vmovl_u8(v_bgr.val[3]);
uint32x4_t vb_low_32 = vmovl_u16(vget_low_u16(vb_16));
uint32x4_t vg_low_32 = vmovl_u16(vget_low_u16(vg_16));
uint32x4_t vr_low_32 = vmovl_u16(vget_low_u16(vr_16));
uint32x4_t vb_high_32 = vmovl_u16(vget_high_u16(vb_16));
uint32x4_t vg_high_32 = vmovl_u16(vget_high_u16(vg_16));
uint32x4_t vr_high_32 = vmovl_u16(vget_high_u16(vr_16));
float32x4_t vb_low_f32 = vcvtq_f32_u32(vb_low_32);
float32x4_t vr_low_f32 = vcvtq_f32_u32(vr_low_32);
float32x4_t vg_low_f32 = vcvtq_f32_u32(vg_low_32);
float32x4_t vb_high_f32 = vcvtq_f32_u32(vb_high_32);
float32x4_t vg_high_f32 = vcvtq_f32_u32(vg_high_32);
float32x4_t vr_high_f32 = vcvtq_f32_u32(vr_high_32);
vb_low_f32 = vsubq_f32(vb_low_f32, vbmean);
vg_low_f32 = vsubq_f32(vg_low_f32, vgmean);
vr_low_f32 = vsubq_f32(vr_low_f32, vrmean);
vb_high_f32 = vsubq_f32(vb_high_f32, vbmean);
vg_high_f32 = vsubq_f32(vg_high_f32, vgmean);
vr_high_f32 = vsubq_f32(vr_high_f32, vrmean);
vb_low_f32 = vmulq_f32(vb_low_f32, vbscale);
vg_low_f32 = vmulq_f32(vg_low_f32, vgscale);
vr_low_f32 = vmulq_f32(vr_low_f32, vrscale);
vb_high_f32 = vmulq_f32(vb_high_f32, vbscale);
vg_high_f32 = vmulq_f32(vg_high_f32, vgscale);
vr_high_f32 = vmulq_f32(vr_high_f32, vrscale);
float32x4x3_t val;
val.val[0] = vb_low_f32;
val.val[1] = vg_low_f32;
val.val[2] = vr_low_f32;
// val.val[3] = num_a;
vst3q_f32(dout_ptr, val);
din_ptr += 32;
dout_ptr += 12;
val.val[0] = vb_high_f32;
val.val[1] = vg_high_f32;
val.val[2] = vr_high_f32;
vst3q_f32(dout_ptr, val);
dout_ptr += 12;
}
for (int j = 0; j < remain; j++) {
*dout_ptr++ = (*din_ptr - b_means) * b_scales;
din_ptr++;
*dout_ptr++ = (*din_ptr - g_means) * g_scales;
din_ptr++;
*dout_ptr++ = (*din_ptr - r_means) * r_scales;
din_ptr++;
din_ptr++; // a
// *dout_ptr++ = 255;
}
}
}
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/utils/cv/paddle_image_preprocess.h"
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
typedef void (*tensor_func)(const uint8_t* src,
float* dst,
int srcw,
int srch,
float* means,
float* scales);
class Image2Tensor {
public:
void choose(const uint8_t* src,
Tensor* dst,
ImageFormat srcFormat,
LayoutType layout,
int srcw,
int srch,
float* means,
float* scales);
private:
tensor_func impl_{nullptr};
};
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
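The dispatcher declared above can also be driven directly, without going through ImagePreprocess. A minimal sketch, assuming a caller-provided BGR buffer, its dimensions, and example normalization values (none of these come from this commit); the destination tensor is assumed to be pre-sized to hold 3 * width * height floats.

#include "lite/utils/cv/image2tensor.h"

// Illustrative helper (hypothetical, not part of this commit): feed an 8-bit
// BGR frame straight into an NCHW float tensor.
void bgr_to_nchw_tensor(const uint8_t* bgr_data, int width, int height,
                        paddle::lite::utils::cv::Tensor* dst) {
  using namespace paddle::lite::utils::cv;
  float means[3] = {103.94f, 116.78f, 123.68f};  // example per-channel means
  float scales[3] = {0.017f, 0.017f, 0.017f};    // example per-channel scales
  Image2Tensor i2t;
  i2t.choose(bgr_data, dst, BGR, LayoutType::kNCHW, width, height, means, scales);
}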
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include "lite/utils/cv/paddle_image_preprocess.h"
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
typedef void (*convert_func)(const uint8_t* src,
uint8_t* dst,
int srcw,
int srch);
class ImageConvert {
public:
void choose(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
ImageFormat dstFormat,
int srcw,
int srch);
private:
convert_func impl_{nullptr};
};
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <vector>
#include "lite/utils/cv/paddle_image_preprocess.h"
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
void flip_hwc1(
const uint8_t* src, uint8_t* dst, int srcw, int srch, FlipParam flip_param);
void flip_hwc3(
const uint8_t* src, uint8_t* dst, int srcw, int srch, FlipParam flip_param);
void flip_hwc4(
const uint8_t* src, uint8_t* dst, int srcw, int srch, FlipParam flip_param);
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// ncnn license
// Tencent is pleased to support the open source community by making ncnn
// available.
//
// Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "lite/utils/cv/image_resize.h"
#include <arm_neon.h>
#include <math.h>
#include <algorithm>
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
void compute_xy(int srcw,
int srch,
int dstw,
int dsth,
double scale_x,
double scale_y,
int* xofs,
int* yofs,
int16_t* ialpha,
int16_t* ibeta);
// use bilinear method to resize
void resize(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
int dstw,
int dsth) {
int size = srcw * srch;
if (srcw == dstw && srch == dsth) {
if (srcFormat == NV12 || srcFormat == NV21) {
size = srcw * (floor(1.5 * srch));
} else if (srcFormat == BGR || srcFormat == RGB) {
size = 3 * srcw * srch;
} else if (srcFormat == BGRA || srcFormat == RGBA) {
size = 4 * srcw * srch;
}
memcpy(dst, src, sizeof(uint8_t) * size);
return;
}
double scale_x = static_cast<double>(srcw) / dstw;
double scale_y = static_cast<double>(srch) / dsth;
int* buf = new int[dstw * 2 + dsth * 2];
int* xofs = buf;
int* yofs = buf + dstw;
int16_t* ialpha = reinterpret_cast<int16_t*>(buf + dstw + dsth);
int16_t* ibeta = reinterpret_cast<int16_t*>(buf + 2 * dstw + dsth);
compute_xy(
srcw, srch, dstw, dsth, scale_x, scale_y, xofs, yofs, ialpha, ibeta);
int w_out = dstw;
int w_in = srcw;
int num = 1;
int orih = dsth;
if (srcFormat == GRAY) {
num = 1;
} else if (srcFormat == NV12 || srcFormat == NV21) {
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) {
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4;
w_out = dstw * 4;
num = 4;
}
int* xofs1 = nullptr;
int* yofs1 = nullptr;
int16_t* ialpha1 = nullptr;
if (orih < dsth) { // uv
int tmp = dsth - orih;
int w = dstw / 2;
xofs1 = new int[w];
yofs1 = new int[tmp];
ialpha1 = new int16_t[srcw];
compute_xy(srcw / 2,
srch / 2,
w,
tmp,
scale_x,
scale_y,
xofs1,
yofs1,
ialpha1,
ibeta + orih);
}
int cnt = w_out >> 3;
int remain = w_out % 8;
int32x4_t _v2 = vdupq_n_s32(2);
#pragma omp parallel for
for (int dy = 0; dy < dsth; dy++) {
int16_t* rowsbuf0 = new int16_t[w_out];
int16_t* rowsbuf1 = new int16_t[w_out];
int sy = yofs[dy];
if (dy >= orih) {
xofs = xofs1;
yofs = yofs1;
ialpha = ialpha1;
}
if (sy < 0) {
memset(rowsbuf0, 0, sizeof(uint16_t) * w_out);
const uint8_t* S1 = src + srcw * (sy + 1);
const int16_t* ialphap = ialpha;
int16_t* rows1p = rowsbuf1;
for (int dx = 0; dx < dstw; dx++) {
int sx = xofs[dx] * num; // num = 4
int16_t a0 = ialphap[0];
int16_t a1 = ialphap[1];
const uint8_t* S1pl = S1 + sx;
const uint8_t* S1pr = S1 + sx + num;
if (sx < 0) {
S1pl = S1;
}
for (int i = 0; i < num; i++) {
if (sx < 0) {
*rows1p++ = ((*S1pl++) * a1) >> 4;
} else {
*rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4;
}
}
ialphap += 2;
}
} else {
// hresize two rows
const uint8_t* S0 = src + w_in * (sy);
const uint8_t* S1 = src + w_in * (sy + 1);
const int16_t* ialphap = ialpha;
int16_t* rows0p = rowsbuf0;
int16_t* rows1p = rowsbuf1;
for (int dx = 0; dx < dstw; dx++) {
int sx = xofs[dx] * num; // num = 4
int16_t a0 = ialphap[0];
int16_t a1 = ialphap[1];
const uint8_t* S0pl = S0 + sx;
const uint8_t* S0pr = S0 + sx + num;
const uint8_t* S1pl = S1 + sx;
const uint8_t* S1pr = S1 + sx + num;
if (sx < 0) {
S0pl = S0;
S1pl = S1;
}
for (int i = 0; i < num; i++) {
if (sx < 0) {
*rows0p = ((*S0pl++) * a1) >> 4;
*rows1p = ((*S1pl++) * a1) >> 4;
rows0p++;
rows1p++;
} else {
*rows0p++ = ((*S0pl++) * a0 + (*S0pr++) * a1) >> 4;
*rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4;
}
}
ialphap += 2;
}
}
int ind = dy * 2;
int16_t b0 = ibeta[ind];
int16_t b1 = ibeta[ind + 1];
int16x8_t _b0 = vdupq_n_s16(b0);
int16x8_t _b1 = vdupq_n_s16(b1);
uint8_t* dp_ptr = dst + dy * w_out;
int16_t* rows0p = rowsbuf0;
int16_t* rows1p = rowsbuf1;
int re_cnt = cnt;
if (re_cnt > 0) {
#ifdef __aarch64__
asm volatile(
"1: \n"
"ld1 {v0.8h}, [%[rows0p]], #16 \n"
"ld1 {v1.8h}, [%[rows1p]], #16 \n"
"orr v6.16b, %w[_v2].16b, %w[_v2].16b \n"
"orr v7.16b, %w[_v2].16b, %w[_v2].16b \n"
"smull v2.4s, v0.4h, %w[_b0].4h \n"
"smull2 v4.4s, v0.8h, %w[_b0].8h \n"
"smull v3.4s, v1.4h, %w[_b1].4h \n"
"smull2 v5.4s, v1.8h, %w[_b1].8h \n"
"ssra v6.4s, v2.4s, #16 \n"
"ssra v7.4s, v4.4s, #16 \n"
"ssra v6.4s, v3.4s, #16 \n"
"ssra v7.4s, v5.4s, #16 \n"
"shrn v0.4h, v6.4s, #2 \n"
"shrn2 v0.8h, v7.4s, #2 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"sqxtun v1.8b, v0.8h \n"
"st1 {v1.8b}, [%[dp]], #8 \n"
"bne 1b \n"
: [rows0p] "+r"(rows0p),
[rows1p] "+r"(rows1p),
[cnt] "+r"(re_cnt),
[dp] "+r"(dp_ptr)
: [_b0] "w"(_b0), [_b1] "w"(_b1), [_v2] "w"(_v2)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
asm volatile(
"mov r4, #2 \n"
"vdup.s32 q12, r4 \n"
"0: \n"
"vld1.s16 {d2-d3}, [%[rows0p]]!\n"
"vld1.s16 {d6-d7}, [%[rows1p]]!\n"
"vorr.s32 q10, q12, q12 \n"
"vorr.s32 q11, q12, q12 \n"
"vmull.s16 q0, d2, %[_b0] \n"
"vmull.s16 q1, d3, %[_b0] \n"
"vmull.s16 q2, d6, %[_b1] \n"
"vmull.s16 q3, d7, %[_b1] \n"
"vsra.s32 q10, q0, #16 \n"
"vsra.s32 q11, q1, #16 \n"
"vsra.s32 q10, q2, #16 \n"
"vsra.s32 q11, q3, #16 \n"
"vshrn.s32 d20, q10, #2 \n"
"vshrn.s32 d21, q11, #2 \n"
"subs %[cnt], #1 \n"
"vqmovun.s16 d20, q10 \n"
"vst1.8 {d20}, [%[dp]]! \n"
"bne 0b \n"
: [rows0p] "+r"(rows0p),
[rows1p] "+r"(rows1p),
[cnt] "+r"(re_cnt),
[dp] "+r"(dp_ptr)
: [_b0] "w"(_b0), [_b1] "w"(_b1)
: "cc",
"memory",
"r4",
"q0",
"q1",
"q2",
"q3",
"q8",
"q9",
"q10",
"q11",
"q12");
#endif // __aarch64__
}
for (int i = 0; i < remain; i++) {
// D[x] = (rows0[x]*b0 + rows1[x]*b1) >>
// INTER_RESIZE_COEF_BITS;
*dp_ptr++ =
(uint8_t)(((int16_t)((b0 * (int16_t)(*rows0p++)) >> 16) +
(int16_t)((b1 * (int16_t)(*rows1p++)) >> 16) + 2) >>
2);
}
delete[] rowsbuf0;
delete[] rowsbuf1;
}
delete[] buf;
// uv tables are only allocated for NV12/NV21; delete[] on nullptr is a no-op
delete[] xofs1;
delete[] yofs1;
delete[] ialpha1;
}
// compute xofs, yofs, alpha, beta
void compute_xy(int srcw,
int srch,
int dstw,
int dsth,
double scale_x,
double scale_y,
int* xofs,
int* yofs,
int16_t* ialpha,
int16_t* ibeta) {
float fy = 0.f;
float fx = 0.f;
int sy = 0;
int sx = 0;
const int resize_coef_bits = 11;
const int resize_coef_scale = 1 << resize_coef_bits;
#define SATURATE_CAST_SHORT(X) \
(int16_t)::std::min( \
::std::max(static_cast<int>(X + (X >= 0.f ? 0.5f : -0.5f)), SHRT_MIN), \
SHRT_MAX);
for (int dx = 0; dx < dstw; dx++) {
fx = static_cast<float>((dx + 0.5) * scale_x - 0.5);
sx = floor(fx);
fx -= sx;
if (sx < 0) {
sx = 0;
fx = 0.f;
}
if (sx >= srcw - 1) {
sx = srcw - 2;
fx = 1.f;
}
xofs[dx] = sx;
float a0 = (1.f - fx) * resize_coef_scale;
float a1 = fx * resize_coef_scale;
ialpha[dx * 2] = SATURATE_CAST_SHORT(a0);
ialpha[dx * 2 + 1] = SATURATE_CAST_SHORT(a1);
}
for (int dy = 0; dy < dsth; dy++) {
fy = static_cast<float>((dy + 0.5) * scale_y - 0.5);
sy = floor(fy);
fy -= sy;
if (sy < 0) {
sy = 0;
fy = 0.f;
}
if (sy >= srch - 1) {
sy = srch - 2;
fy = 1.f;
}
yofs[dy] = sy;
float b0 = (1.f - fy) * resize_coef_scale;
float b1 = fy * resize_coef_scale;
ibeta[dy * 2] = SATURATE_CAST_SHORT(b0);
ibeta[dy * 2 + 1] = SATURATE_CAST_SHORT(b1);
}
#undef SATURATE_CAST_SHORT
}
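// A scalar reference of the fixed-point scheme used by resize()/compute_xy()
// above (illustration only; resize_pixel_ref is a hypothetical helper, not
// part of this commit). The interpolation weights are Q11 fixed point
// (resize_coef_scale = 2048). The horizontal pass pre-shifts the accumulator
// by 4 so each row value fits in int16 (255 * 2048 >> 4 = 32640), and the
// vertical pass shifts the remaining product down by 16 and then by 2 with
// +2 rounding to recover an 8-bit pixel.
static inline uint8_t resize_pixel_ref(uint8_t s00, uint8_t s01,   // top row
                                       uint8_t s10, uint8_t s11,   // bottom row
                                       int16_t a0, int16_t a1,     // ialpha (Q11)
                                       int16_t b0, int16_t b1) {   // ibeta (Q11)
  int16_t row0 = (s00 * a0 + s01 * a1) >> 4;  // horizontal interpolation
  int16_t row1 = (s10 * a0 + s11 * a1) >> 4;
  int acc = ((b0 * row0) >> 16) + ((b1 * row1) >> 16);  // vertical interpolation
  return static_cast<uint8_t>((acc + 2) >> 2);          // round and narrow
}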
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// ncnn license
// Tencent is pleased to support the open source community by making ncnn
// available.
//
// Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this
// file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#pragma once
#include <math.h>
#include <stdint.h>
#include "lite/utils/cv/paddle_image_preprocess.h"
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
void resize(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
int dstw,
int dsth);
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
This diff is collapsed.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <vector>
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
void rotate_hwc1(
const uint8_t* src, uint8_t* dst, int srcw, int srch, float degree);
void rotate_hwc3(
const uint8_t* src, uint8_t* dst, int srcw, int srch, float degree);
void rotate_hwc4(
const uint8_t* src, uint8_t* dst, int srcw, int srch, float degree);
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/utils/cv/paddle_image_preprocess.h"
#include <math.h>
#include <algorithm>
#include <climits>
#include "lite/utils/cv/image2tensor.h"
#include "lite/utils/cv/image_convert.h"
#include "lite/utils/cv/image_flip.h"
#include "lite/utils/cv/image_resize.h"
#include "lite/utils/cv/image_rotate.h"
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
#define PI 3.14159265f
#define Degrees2Radians(degrees) ((degrees) * (PI / 180))
#define Radians2Degrees(radians) ((radians) * (180 / PI))
#define ScalarNearlyZero (1.0f / (1 << 12))
// init
ImagePreprocess::ImagePreprocess(ImageFormat srcFormat,
ImageFormat dstFormat,
TransParam param) {
this->srcFormat_ = srcFormat;
this->dstFormat_ = dstFormat;
this->transParam_ = param;
}
void ImagePreprocess::imageCovert(const uint8_t* src, uint8_t* dst) {
ImageConvert img_convert;
img_convert.choose(src,
dst,
this->srcFormat_,
this->dstFormat_,
this->transParam_.iw,
this->transParam_.ih);
}
void ImagePreprocess::imageCovert(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
ImageFormat dstFormat) {
ImageConvert img_convert;
img_convert.choose(src,
dst,
srcFormat,
dstFormat,
this->transParam_.iw,
this->transParam_.ih);
}
void ImagePreprocess::imageResize(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
int dstw,
int dsth) {
resize(src, dst, srcFormat, srcw, srch, dstw, dsth);
/*
int size = srcw * srch;
if (srcw == dstw && srch == dsth) {
if (srcFormat == NV12 || srcFormat == NV21) {
size = srcw * (floor(1.5 * srch));
} else if (srcFormat == BGR || srcFormat == RGB) {
size = 3 * srcw * srch;
} else if (srcFormat == BGRA || srcFormat == RGBA) {
size = 4 * srcw * srch;
}
memcpy(dst, src, sizeof(uint8_t) * size);
return;
}
double scale_x = static_cast<double>(srcw / dstw);
double scale_y = static_cast<double>(srch / dsth);
int* buf = new int[dstw * 2 + dsth * 2];
int* xofs = buf;
int* yofs = buf + dstw;
int16_t* ialpha = reinterpret_cast<int16_t*>(buf + dstw + dsth);
int16_t* ibeta = reinterpret_cast<int16_t*>(buf + 2 * dstw + dsth);
compute_xy(
srcw, srch, dstw, dsth, scale_x, scale_y, xofs, yofs, ialpha, ibeta);
int w_out = dstw;
int w_in = srcw;
int num = 1;
int orih = dsth;
if (srcFormat == GRAY) {
num = 1;
} else if (srcFormat == NV12 || srcFormat == NV21) {
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) {
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4;
w_out = dstw * 4;
num = 4;
}
int* xofs1 = nullptr;
int* yofs1 = nullptr;
int16_t* ialpha1 = nullptr;
if (orih < dsth) { // uv
int tmp = dsth - orih;
int w = dstw / 2;
xofs1 = new int[w];
yofs1 = new int[tmp];
ialpha1 = new int16_t[srcw];
compute_xy(srcw / 2,
srch / 2,
w,
tmp,
scale_x,
scale_y,
xofs1,
yofs1,
ialpha1,
ibeta + orih);
}
int cnt = w_out >> 3;
int remain = w_out % 8;
int32x4_t _v2 = vdupq_n_s32(2);
#pragma omp parallel for
for (int dy = 0; dy < dsth; dy++) {
int16_t* rowsbuf0 = new int16_t[w_out];
int16_t* rowsbuf1 = new int16_t[w_out];
int sy = yofs[dy];
if (dy >= orih) {
xofs = xofs1;
yofs = yofs1;
ialpha = ialpha1;
}
if (sy < 0) {
memset(rowsbuf0, 0, sizeof(uint16_t) * w_out);
const uint8_t* S1 = src + srcw * (sy + 1);
const int16_t* ialphap = ialpha;
int16_t* rows1p = rowsbuf1;
for (int dx = 0; dx < dstw; dx++) {
int sx = xofs[dx] * num; // num = 4
int16_t a0 = ialphap[0];
int16_t a1 = ialphap[1];
const uint8_t* S1pl = S1 + sx;
const uint8_t* S1pr = S1 + sx + num;
if (sx < 0) {
S1pl = S1;
}
for (int i = 0; i < num; i++) {
if (sx < 0) {
*rows1p++ = ((*S1pl++) * a1) >> 4;
} else {
*rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4;
}
}
ialphap += 2;
}
} else {
// hresize two rows
const uint8_t* S0 = src + w_in * (sy);
const uint8_t* S1 = src + w_in * (sy + 1);
const int16_t* ialphap = ialpha;
int16_t* rows0p = rowsbuf0;
int16_t* rows1p = rowsbuf1;
for (int dx = 0; dx < dstw; dx++) {
int sx = xofs[dx] * num; // num = 4
int16_t a0 = ialphap[0];
int16_t a1 = ialphap[1];
const uint8_t* S0pl = S0 + sx;
const uint8_t* S0pr = S0 + sx + num;
const uint8_t* S1pl = S1 + sx;
const uint8_t* S1pr = S1 + sx + num;
if (sx < 0) {
S0pl = S0;
S1pl = S1;
}
for (int i = 0; i < num; i++) {
if (sx < 0) {
*rows0p = ((*S0pl++) * a1) >> 4;
*rows1p = ((*S1pl++) * a1) >> 4;
rows0p++;
rows1p++;
} else {
*rows0p++ = ((*S0pl++) * a0 + (*S0pr++) * a1) >> 4;
*rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4;
}
}
ialphap += 2;
}
}
int ind = dy * 2;
int16_t b0 = ibeta[ind];
int16_t b1 = ibeta[ind + 1];
int16x8_t _b0 = vdupq_n_s16(b0);
int16x8_t _b1 = vdupq_n_s16(b1);
uint8_t* dp_ptr = dst + dy * w_out;
int16_t* rows0p = rowsbuf0;
int16_t* rows1p = rowsbuf1;
int re_cnt = cnt;
if (re_cnt > 0) {
#ifdef __aarch64__
asm volatile(
"1: \n"
"ld1 {v0.8h}, [%[rows0p]], #16 \n"
"ld1 {v1.8h}, [%[rows1p]], #16 \n"
"orr v6.16b, %w[_v2].16b, %w[_v2].16b \n"
"orr v7.16b, %w[_v2].16b, %w[_v2].16b \n"
"smull v2.4s, v0.4h, %w[_b0].4h \n"
"smull2 v4.4s, v0.8h, %w[_b0].8h \n"
"smull v3.4s, v1.4h, %w[_b1].4h \n"
"smull2 v5.4s, v1.8h, %w[_b1].8h \n"
"ssra v6.4s, v2.4s, #16 \n"
"ssra v7.4s, v4.4s, #16 \n"
"ssra v6.4s, v3.4s, #16 \n"
"ssra v7.4s, v5.4s, #16 \n"
"shrn v0.4h, v6.4s, #2 \n"
"shrn2 v0.8h, v7.4s, #2 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"sqxtun v1.8b, v0.8h \n"
"st1 {v1.8b}, [%[dp]], #8 \n"
"bne 1b \n"
: [rows0p] "+r"(rows0p),
[rows1p] "+r"(rows1p),
[cnt] "+r"(re_cnt),
[dp] "+r"(dp_ptr)
: [_b0] "w"(_b0), [_b1] "w"(_b1), [_v2] "w"(_v2)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
asm volatile(
"mov r4, #2 \n"
"vdup.s32 q12, r4 \n"
"0: \n"
"vld1.s16 {d2-d3}, [%[rows0p]]!\n"
"vld1.s16 {d6-d7}, [%[rows1p]]!\n"
"vorr.s32 q10, q12, q12 \n"
"vorr.s32 q11, q12, q12 \n"
"vmull.s16 q0, d2, %[_b0] \n"
"vmull.s16 q1, d3, %[_b0] \n"
"vmull.s16 q2, d6, %[_b1] \n"
"vmull.s16 q3, d7, %[_b1] \n"
"vsra.s32 q10, q0, #16 \n"
"vsra.s32 q11, q1, #16 \n"
"vsra.s32 q10, q2, #16 \n"
"vsra.s32 q11, q3, #16 \n"
"vshrn.s32 d20, q10, #2 \n"
"vshrn.s32 d21, q11, #2 \n"
"subs %[cnt], #1 \n"
"vqmovun.s16 d20, q10 \n"
"vst1.8 {d20}, [%[dp]]! \n"
"bne 0b \n"
: [rows0p] "+r"(rows0p),
[rows1p] "+r"(rows1p),
[cnt] "+r"(re_cnt),
[dp] "+r"(dp_ptr)
: [_b0] "w"(_b0), [_b1] "w"(_b1)
: "cc",
"memory",
"r4",
"q0",
"q1",
"q2",
"q3",
"q8",
"q9",
"q10",
"q11",
"q12");
#endif // __aarch64__
}
for (int i = 0; i < remain; i++) {
// D[x] = (rows0[x]*b0 + rows1[x]*b1) >>
// INTER_RESIZE_COEF_BITS;
*dp_ptr++ =
(uint8_t)(((int16_t)((b0 * (int16_t)(*rows0p++)) >> 16) +
(int16_t)((b1 * (int16_t)(*rows1p++)) >> 16) + 2) >>
2);
}
}
delete[] buf;
*/
}
void ImagePreprocess::imageResize(const uint8_t* src, uint8_t* dst) {
int srcw = this->transParam_.iw;
int srch = this->transParam_.ih;
int dstw = this->transParam_.ow;
int dsth = this->transParam_.oh;
auto srcFormat = this->dstFormat_;
resize(src, dst, srcFormat, srcw, srch, dstw, dsth);
}
void ImagePreprocess::imageRotate(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
float degree) {
if (degree != 90 && degree != 180 && degree != 270) {
printf("this degree: %f is not supported \n", degree);
return;
}
if (srcFormat == GRAY) {
rotate_hwc1(src, dst, srcw, srch, degree);
} else if (srcFormat == BGR || srcFormat == RGB) {
rotate_hwc3(src, dst, srcw, srch, degree);
} else if (srcFormat == BGRA || srcFormat == RGBA) {
rotate_hwc4(src, dst, srcw, srch, degree);
} else {
printf("this srcFormat: %d does not support! \n", srcFormat);
return;
}
}
void ImagePreprocess::imageRotate(const uint8_t* src, uint8_t* dst) {
auto srcw = this->transParam_.ow;
auto srch = this->transParam_.oh;
auto srcFormat = this->dstFormat_;
auto degree = this->transParam_.rotate_param;
if (degree != 90 && degree != 180 && degree != 270) {
printf("this degree: %f is not supported \n", degree);
return;
}
ImagePreprocess::imageRotate(src, dst, srcFormat, srcw, srch, degree);
}
void ImagePreprocess::imageFlip(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
FlipParam flip_param) {
if (srcFormat == GRAY) {
flip_hwc1(src, dst, srcw, srch, flip_param);
} else if (srcFormat == BGR || srcFormat == RGB) {
flip_hwc3(src, dst, srcw, srch, flip_param);
} else if (srcFormat == BGRA || srcFormat == RGBA) {
flip_hwc4(src, dst, srcw, srch, flip_param);
} else {
printf("this srcFormat: %d does not support! \n", srcFormat);
return;
}
}
void ImagePreprocess::imageFlip(const uint8_t* src, uint8_t* dst) {
auto srcw = this->transParam_.ow;
auto srch = this->transParam_.oh;
auto srcFormat = this->dstFormat_;
auto flip_param = this->transParam_.flip_param;
ImagePreprocess::imageFlip(src, dst, srcFormat, srcw, srch, flip_param);
}
void ImagePreprocess::image2Tensor(const uint8_t* src,
Tensor* dstTensor,
ImageFormat srcFormat,
int srcw,
int srch,
LayoutType layout,
float* means,
float* scales) {
Image2Tensor img2tensor;
img2tensor.choose(
src, dstTensor, srcFormat, layout, srcw, srch, means, scales);
}
void ImagePreprocess::image2Tensor(const uint8_t* src,
Tensor* dstTensor,
LayoutType layout,
float* means,
float* scales) {
Image2Tensor img2tensor;
img2tensor.choose(src,
dstTensor,
this->dstFormat_,
layout,
this->transParam_.ow,
this->transParam_.oh,
means,
scales);
}
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_place.h"
namespace paddle {
namespace lite {
namespace utils {
namespace cv {
typedef paddle::lite_api::Tensor Tensor;
typedef paddle::lite_api::DataLayoutType LayoutType;
// color enum
enum ImageFormat {
RGBA = 0,
BGRA,
RGB,
BGR,
GRAY,
NV21 = 11,
NV12,
};
// flip enum
enum FlipParam {
X = 0, // flip along the X axis
Y, // flip along the Y axis
XY // flip along the XY axis
};
// transform param
typedef struct {
int ih; // input height
int iw; // input width
int oh; // output height
int ow; // output width
FlipParam flip_param; // flip, support x, y, xy
float rotate_param; // rotate, support 90, 180, 270
} TransParam;
class ImagePreprocess {
public:
/*
* init
* param srcFormat: input image color
* param dstFormat: output image color
* param param: input image parameters, e.g. input and output size
*/
ImagePreprocess(ImageFormat srcFormat,
ImageFormat dstFormat,
TransParam param);
/*
* image color convert
* support NV12/NV21_to_BGR(RGB), NV12/NV21_to_BGRA(RGBA),
* BGR(RGB) and BGRA(RGBA) transform,
* BGR(RGB) and RGB(BGR) transform,
* BGR(RGB) and RGBA(BGRA) transform,
* BGR(RGB) and GRAY transform,
* param src: input image data
* param dst: output image data
*/
void imageCovert(const uint8_t* src, uint8_t* dst);
/*
* image color convert
* support NV12/NV21_to_BGR(RGB), NV12/NV21_to_BGRA(RGBA),
* BGR(RGB) and BGRA(RGBA) transform,
* BGR(RGB) and RGB(BGR) transform,
* BGR(RGB) and RGBA(BGRA) transform,
* BGR(RGB) and GRAY transform,
* param src: input image data
* param dst: output image data
* param srcFormat: input image format, support GRAY, NV12(NV21),
* BGR(RGB) and BGRA(RGBA)
* param dstFormat: output image format, support GRAY, BGR(RGB) and
* BGRA(RGBA)
*/
void imageCovert(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
ImageFormat dstFormat);
/*
* image resize, use bilinear method
* support image format: 1-channel image (e.g. GRAY), 2-channel image (e.g.
* NV12, NV21), 3-channel image (e.g. BGR), 4-channel image (e.g. BGRA)
* param src: input image data
* param dst: output image data
*/
void imageResize(const uint8_t* src, uint8_t* dst);
/*
* image resize, use bilinear method
* support image format: 1-channel image (e.g. GRAY), 2-channel image (e.g.
* NV12, NV21), 3-channel image (e.g. BGR), 4-channel image (e.g. BGRA)
* param src: input image data
* param dst: output image data
* param srcFormat: input image format, support GRAY, NV12(NV21), BGR(RGB)
* and BGRA(RGBA)
* param srcw: input image width
* param srch: input image height
* param dstw: output image width
* param dsth: output image height
*/
void imageResize(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
int dstw,
int dsth);
/*
* image Rotate
* support 90, 180 and 270 Rotate process
* color format support 1-channel image, 3-channel image and 4-channel image
* param src: input image data
* param dst: output image data
*/
void imageRotate(const uint8_t* src, uint8_t* dst);
/*
* image Rotate
* support 90, 180 and 270 Rotate process
* color format support 1-channel image, 3-channel image and 4-channel image
* param src: input image data
* param dst: output image data
* param srcFormat: input image format, support GRAY, BGR(RGB) and BGRA(RGBA)
* param srcw: input image width
* param srch: input image height
* param degree: Rotate degree, support 90, 180 and 270
*/
void imageRotate(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
float degree);
/*
* image Flip
* support X, Y and XY flip process
* color format support 1-channel image, 3-channel image and 4-channel image
* param src: input image data
* param dst: output image data
*/
void imageFlip(const uint8_t* src, uint8_t* dst);
/*
* image Flip
* support X, Y and XY flip process
* color format support 1-channel image, 3-channel image and 4-channel image
* param src: input image data
* param dst: output image data
* param srcFormat: input image format, support GRAY, BGR(RGB) and BGRA(RGBA)
* param srcw: input image width
* param srch: input image height
* param flip_param: flip parameter, support X, Y and XY
*/
void imageFlip(const uint8_t* src,
uint8_t* dst,
ImageFormat srcFormat,
int srcw,
int srch,
FlipParam flip_param);
/*
* change image data to tensor data
* supported image formats are BGR(RGB) and BGRA(RGBA); supported data
* layouts are NHWC and NCHW
* param src: input image data
* param dstTensor: output tensor data
* param layout: output tensor layout, support NHWC and NCHW
* param means: means of image
* param scales: scales of image
*/
void image2Tensor(const uint8_t* src,
Tensor* dstTensor,
LayoutType layout,
float* means,
float* scales);
/*
* change image data to tensor data
* supported image formats are BGR(RGB) and BGRA(RGBA); supported data
* layouts are NHWC and NCHW
* param src: input image data
* param dstTensor: output tensor data
* param srcFormat: input image format, support BGR(RGB) and BGRA(RGBA)
* param srcw: input image width
* param srch: input image height
* param layout: output tensor layout, support NHWC and NCHW
* param means: means of image
* param scales: scales of image
*/
void image2Tensor(const uint8_t* src,
Tensor* dstTensor,
ImageFormat srcFormat,
int srcw,
int srch,
LayoutType layout,
float* means,
float* scales);
private:
ImageFormat srcFormat_;
ImageFormat dstFormat_;
TransParam transParam_;
};
} // namespace cv
} // namespace utils
} // namespace lite
} // namespace paddle