diff --git a/.gitignore b/.gitignore
index 9db2912c07bc2d6abb01c322a25519ac0ff158fa..ed131bdbbad6bd4dad500fa29f40a29fddeb7593 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,7 @@
 build/
 build_fpga/
+docs/_build/
 .idea/
diff --git a/.travis.yml b/.travis.yml
index c902afef91b816390170f1b7e1c8e4b07c7b0645..bee77d08304881c718483b88e1ea7e55228483e2 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,7 @@
 language: cpp
 cache: ccache
 sudo: required
-dist: trusty
+dist: xenial
 os:
   - linux
@@ -18,7 +18,7 @@ addons:
     - clang-format-3.8
 before_install:
-  - sudo pip install cpplint pre-commit
+  - sudo pip install cpplint pre-commit==1.10.3
   - sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format
   # Download and install recent cmake
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 77a94bea1efcdafaa67b4c078bfb0a756f7b1cec..e3f7a211d70920aa74765b976af6939d55a328ab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -60,6 +60,7 @@ lite_option(LITE_WITH_X86 "Enable X86 in lite mode" ON)
 lite_option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
 lite_option(LITE_WITH_NPU "Enable NPU in lite mode" OFF)
 lite_option(LITE_WITH_XPU "Enable XPU in lite mode" OFF)
+lite_option(LITE_WITH_BM "Enable BM in lite mode" OFF)
 lite_option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON)
 lite_option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF)
 lite_option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF)
@@ -73,8 +74,8 @@ lite_option(LITE_ON_MODEL_OPTIMIZE_TOOL "Build the model optimize tool" OFF)
 lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kernels and operators" OFF)
 lite_option(LITE_BUILD_TAILOR "Enable tailoring library according to model" OFF)
 # cv build options
-lite_option(LITE_WITH_CV "Enable build cv image in lite" OFF IF NOT LITE_WITH_ARM)
-
+lite_option(LITE_WITH_CV "Enable build cv image in lite" OFF)
+lite_option(LITE_WITH_STATIC_CUDA "Statically link cuda libraries." ON)
 # TODO(Superjomn) Remove WITH_ANAKIN option if not needed latter.
 if(ANDROID OR IOS OR ARMLINUX)
@@ -169,6 +170,10 @@ endif()
 ########################################################################################
+if(LITE_WITH_XPU)
+    include(xpu)
+endif()
+
 include(external/mklml)    # download mklml package
 include(external/xbyak)    # download xbyak package
 include(external/libxsmm)  # download, build, install libxsmm
@@ -188,10 +193,9 @@ if(LITE_WITH_CUDA)
     include(cuda)
 endif()
-if(LITE_WITH_XPU)
-    include(xpu)
+if(LITE_WITH_BM)
+    include(bm)
 endif()
-
 include(generic)            # simplify cmake module
 include(ccache)             # set ccache for compilation
 include(util)               # set unittest and link libs
diff --git a/README.md b/README.md
index 23974beee9a8af5ee7e2c454575efff2e3d96ee2..22b84888294b5ef60c3d91d7a7909aef8f601d81 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ Framework compatibility: In addition to models trained on PaddlePaddle, those tr
 Paddle Lite is designed to support a wide range of hardware and devices, and it enables mixed execution of a single model on multiple devices, optimization on various phases, and light-weight applications on devices.
-![img](https://github.com/Superjomn/_tmp_images/raw/master/images/paddle-lite-architecture.png)
+![img](https://user-images.githubusercontent.com/45189361/70908123-6ce4fd00-2045-11ea-97e1-ad08446c5c86.png)
 As is shown in the figure above, the analysis phase includes the Machine IR module, which enables optimizations like op fusion and redundant computation pruning.
Besides, the execution phase only involves kernel execution, so it can be deployed on its own to ensure maximally lightweight deployment.
diff --git a/README_cn.md b/README_cn.md
index 99d38c47ffbbaa3b8593801701e3528167899f97..11d3967fe8ce88826ca982b71d96268c1a7e5c3a 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -34,7 +34,7 @@ Paddle Lite为Paddle-Mobile的升级版,定位支持包括手机移动端在
 PaddleLite 的架构设计着重考虑了对多硬件和平台的支持,并且强化了多个硬件在一个模型中混合执行的能力,多个层面的性能优化处理,以及对端侧应用的轻量化设计。
-![](https://github.com/Superjomn/_tmp_images/raw/master/images/paddle-lite-architecture.png)
+![](https://user-images.githubusercontent.com/45189361/70908123-6ce4fd00-2045-11ea-97e1-ad08446c5c86.png)
 其中,Analysis Phase 包括了 MIR(Machine IR) 相关模块,能够对原有的模型的计算图针对具体的硬件列表进行算子融合、计算裁剪 在内的多种优化。Execution Phase 只涉及到Kernel 的执行,且可以单独部署,以支持极致的轻量级部署。
diff --git a/cmake/bm.cmake b/cmake/bm.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..3a3abb5966172ba00227e9fac7fabfe55bac7737
--- /dev/null
+++ b/cmake/bm.cmake
@@ -0,0 +1,80 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT LITE_WITH_BM)
+  return()
+endif()
+
+if(NOT DEFINED BM_SDK_ROOT)
+  set(BM_SDK_ROOT $ENV{BM_SDK_ROOT})
+  if(NOT BM_SDK_ROOT)
+    message(FATAL_ERROR "Must set BM_SDK_ROOT or env BM_SDK_ROOT when LITE_WITH_BM=ON")
+  endif()
+endif()
+
+message(STATUS "BM_SDK_ROOT: ${BM_SDK_ROOT}")
+find_path(BM_SDK_INC NAMES bmruntime_interface.h
+  PATHS ${BM_SDK_ROOT}/include/bmruntime NO_DEFAULT_PATH)
+if(NOT BM_SDK_INC)
+  message(FATAL_ERROR "Can not find bmruntime_interface.h in ${BM_SDK_ROOT}/include")
+endif()
+
+include_directories("${BM_SDK_ROOT}/include/bmruntime")
+include_directories("${BM_SDK_ROOT}/include/bmlib")
+include_directories("${BM_SDK_ROOT}/include/bmcompiler")
+include_directories("${BM_SDK_ROOT}/include/bmcpu")
+include_directories("${BM_SDK_ROOT}/include/bmlog")
+
+find_library(BM_SDK_RT_LIB NAMES bmrt
+  PATHS ${BM_SDK_ROOT}/lib/bmnn/pcie)
+if(NOT BM_SDK_RT_LIB)
+  message(FATAL_ERROR "Can not find bmrt Library in ${BM_SDK_ROOT}")
+else()
+  message(STATUS "Found bmrt Library: ${BM_SDK_RT_LIB}")
+  add_library(bmrt SHARED IMPORTED GLOBAL)
+  set_property(TARGET bmrt PROPERTY IMPORTED_LOCATION ${BM_SDK_RT_LIB})
+endif()
+
+find_library(BM_SDK_BM_LIB NAMES bmlib
+  PATHS ${BM_SDK_ROOT}/lib/bmnn/pcie)
+if(NOT BM_SDK_BM_LIB)
+  message(FATAL_ERROR "Can not find bmlib Library in ${BM_SDK_ROOT}")
+else()
+  message(STATUS "Found bmlib Library: ${BM_SDK_BM_LIB}")
+  add_library(bmlib SHARED IMPORTED GLOBAL)
+  set_property(TARGET bmlib PROPERTY IMPORTED_LOCATION ${BM_SDK_BM_LIB})
+endif()
+
+find_library(BM_SDK_COMPILER_LIB NAMES bmcompiler
+  PATHS ${BM_SDK_ROOT}/lib/bmcompiler)
+if(NOT BM_SDK_COMPILER_LIB)
+  message(FATAL_ERROR "Can not find bmcompiler Library in ${BM_SDK_ROOT}")
+else()
+  message(STATUS "Found bmcompiler Library: ${BM_SDK_COMPILER_LIB}")
+  add_library(bmcompiler SHARED IMPORTED GLOBAL)
+  set_property(TARGET bmcompiler PROPERTY IMPORTED_LOCATION ${BM_SDK_COMPILER_LIB})
+endif()
+
+find_library(BM_SDK_CPU_LIB NAMES bmcpu + PATHS ${BM_SDK_ROOT}/lib/bmnn/pcie) +if(NOT BM_SDK_CPU_LIB) + message(FATAL_ERROR "Can not find bmcpu Library in ${BM_SDK_ROOT}") +else() + message(STATUS "Found bmcpu Library: ${BM_SDK_CPU_LIB}") + add_library(bmcpu SHARED IMPORTED GLOBAL) + set_property(TARGET bmcpu PROPERTY IMPORTED_LOCATION ${BM_SDK_CPU_LIB}) +endif() + +set(bm_runtime_libs bmrt bmlib bmcompiler bmcpu CACHE INTERNAL "bm runtime libs") +set(bm_builder_libs bmrt bmlib bmcompiler bmcpu CACHE INTERNAL "bm builder libs") diff --git a/cmake/configure.cmake b/cmake/configure.cmake index bc055d3186c6bfd77ff6a5e9f979af5082fa34e3..752b22461d9d1c36b3ca6a0bfe472a5dcc3ab976 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -143,6 +143,10 @@ if (LITE_WITH_FPGA) add_definitions("-DLITE_WITH_FPGA") endif() +if (LITE_WITH_BM) +add_definitions("-DLITE_WITH_BM") +endif() + if (LITE_WITH_PROFILE) add_definitions("-DLITE_WITH_PROFILE") if (LITE_WITH_PRECISION_PROFILE) diff --git a/cmake/cross_compiling/ios.cmake b/cmake/cross_compiling/ios.cmake index 76f62765aff791594123d689341b0876b3d0184d..0597ef0cc4ba4c0bcec172c767d66d0f362e1459 100644 --- a/cmake/cross_compiling/ios.cmake +++ b/cmake/cross_compiling/ios.cmake @@ -120,6 +120,7 @@ # ## Lite settings +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto") if (ARM_TARGET_OS STREQUAL "ios") set(PLATFORM "OS") elseif(ARM_TARGET_OS STREQUAL "ios64") diff --git a/cmake/cross_compiling/npu.cmake b/cmake/cross_compiling/npu.cmake index 25aa4d2bc8c1c145e7a103c9164e1c9e231a8f9e..c22bb1db4fbf8a7370ff3e7c9aca40cc94d550a2 100644 --- a/cmake/cross_compiling/npu.cmake +++ b/cmake/cross_compiling/npu.cmake @@ -30,7 +30,7 @@ if(NOT NPU_DDK_INC) message(FATAL_ERROR "Can not find HiAiModelManagerService.h in ${NPU_DDK_ROOT}/include") endif() -include_directories("${NPU_DDK_ROOT}") +include_directories("${NPU_DDK_ROOT}/include") set(NPU_SUB_LIB_PATH "lib64") if(ARM_TARGET_ARCH_ABI STREQUAL "armv8") diff --git a/cmake/cross_compiling/postproject.cmake b/cmake/cross_compiling/postproject.cmake index 88ac3e101a686cb49ef5a4c3b1879c15b8f7b57b..7466b3e6d438277ad31020f76665bf689df436f5 100644 --- a/cmake/cross_compiling/postproject.cmake +++ b/cmake/cross_compiling/postproject.cmake @@ -63,7 +63,7 @@ if (LITE_ON_TINY_PUBLISH) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions") endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto -fvisibility=hidden -fvisibility-inlines-hidden -fdata-sections -ffunction-sections") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -ffunction-sections") check_linker_flag(-Wl,--gc-sections) endif() diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 9ff908a4c87d55e87468a06ae0e6085ac165a1b1..cfbda63f6d784a55803e3d3a44b9ec6a987bd964 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -174,15 +174,44 @@ if(NOT WITH_DSO) endif(WIN32) endif(NOT WITH_DSO) -get_filename_component(CUDA_LIB_PATH ${CUDA_curand_LIBRARY} DIRECTORY) -function(import_static_library alias path) - add_library(${alias} STATIC IMPORTED GLOBAL) - set_property(TARGET ${alias} PROPERTY IMPORTED_LOCATION ${path}) +function(add_cuda_lib TARGET_NAME) + set(options STATIC SHARED) + set(oneValueArgs "NAME") + set(multiValueArgs "PATHS") + cmake_parse_arguments(add_cuda_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + unset(ABS_PATH CACHE) + if (NOT add_cuda_lib_PATHS) 
+ set(add_cuda_lib_PATHS CUDNN_CHECK_LIBRARY_DIRS) + endif() + find_library(ABS_PATH NAMES ${add_cuda_lib_NAME} PATHS ${${add_cuda_lib_PATHS}} NO_DEFAULT_PATH) + add_library(${TARGET_NAME} SHARED IMPORTED GLOBAL) + set_property(TARGET ${TARGET_NAME} PROPERTY IMPORTED_LOCATION ${ABS_PATH}) + set(CUDA_MODULES ${CUDA_MODULES} ${TARGET_NAME} PARENT_SCOPE) + if (NOT ABS_PATH) + message(FATAL_ERROR "Can not find CUDA library: ${add_cuda_lib_NAME}") + endif() endfunction() -import_static_library(cudart_static ${CUDA_LIB_PATH}/libcudart_static.a) -import_static_library(cublas_static ${CUDA_LIB_PATH}/libcublas_static.a) -import_static_library(curand_static ${CUDA_LIB_PATH}/libcurand_static.a) -import_static_library(culibos_static ${CUDA_LIB_PATH}/libculibos.a) + +if(LITE_WITH_STATIC_CUDA) + message(STATUS "Static link CUDA toolkit.") + add_cuda_lib(cudart_static STATIC NAME libcudart_static.a) + add_cuda_lib(cublas_static STATIC NAME libcublas_static.a) + add_cuda_lib(curand_static STATIC NAME libcurand_static.a) + add_cuda_lib(culibos_static STATIC NAME libculibos.a) + if(NOT ${CUDA_VERSION} LESS 10.1) + add_cuda_lib(cublasLt_static STATIC NAME libcublasLt_static.a) + endif() + set_property(GLOBAL PROPERTY CUDA_MODULES cudnn_static ${CUDA_MODULES}) +else() + message(STATUS "Dynamic Link CUDA toolkit.") + add_cuda_lib(cudart SHARED NAME libcudart.so) + add_cuda_lib(cublas SHARED NAME libcublas.so) + add_cuda_lib(curand SHARED NAME libcurand.so) + if(NOT ${CUDA_VERSION} LESS 10.1) + add_cuda_lib(cublasLt SHARED NAME libcublasLt.so) + endif() + set_property(GLOBAL PROPERTY CUDA_MODULES cudnn ${CUDA_MODULES}) +endif() # setting nvcc arch flags select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index c0cb4ccea67cd493a30a6be43ee6ee48f70c36bf..d1386a6c7db08d140648106479a4e37947255c80 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -32,9 +32,9 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS $ENV{CUDNN_ROOT}/lib64 $ENV{CUDNN_ROOT}/lib /usr/lib - ${CUDA_TOOLKIT_ROOT_DIR} - ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 - ) + ${CUDA_TOOLKIT_ROOT_DIR} + ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib64) if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0)) find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH) @@ -69,9 +69,15 @@ if(CUDNN_FOUND) file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) get_filename_component(CUDNN_LIB_PATH ${CUDNN_LIBRARY} DIRECTORY) - add_library(cudnn_static STATIC IMPORTED GLOBAL) - set_property(TARGET cudnn_static PROPERTY IMPORTED_LOCATION + if(LITE_WITH_STATIC_CUDA) + add_library(cudnn_static STATIC IMPORTED GLOBAL) + set_property(TARGET cudnn_static PROPERTY IMPORTED_LOCATION "${CUDNN_LIB_PATH}/libcudnn_static.a") + else() + add_library(cudnn SHARED IMPORTED GLOBAL) + set_property(TARGET cudnn PROPERTY IMPORTED_LOCATION + "${CUDNN_LIB_PATH}/libcudnn.so") + endif(LITE_WITH_STATIC_CUDA) string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)" CUDNN_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index bd0d117a633824d93c403b8167ff49505160069b..599e7bba7eaf12da7506ce44e706bd9f50ec6998 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -1,5 +1,6 @@ INCLUDE(ExternalProject) +SET(EIGEN_SOURCECODE_DIR ${CMAKE_SOURCE_DIR}/third-party/eigen3) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3) 
INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR}) @@ -16,9 +17,12 @@ if(WITH_AMD_GPU) ExternalProject_Add( extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY "https://github.com/sabreshao/hipeigen.git" - GIT_TAG 7cb2b6e5a4b4a1efe658abb215cd866c6fb2275e + GIT_TAG + URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2Fhipeigen-upstream-702834151eaebcf955fd09ed0ad83c06.zip + DOWNLOAD_DIR ${EIGEN_SOURCECODE_DIR} + DOWNLOAD_NO_PROGRESS 1 PREFIX ${EIGEN_SOURCE_DIR} + DOWNLOAD_NAME "hipeigen-upstream-702834151eaebcf955fd09ed0ad83c06.zip" UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -29,12 +33,14 @@ else() ExternalProject_Add( extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY "https://github.com/eigenteam/eigen-git-mirror" # eigen on cuda9.1 missing header of math_funtions.hpp # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen - GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c + GIT_TAG + URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2Feigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip + DOWNLOAD_DIR ${EIGEN_SOURCECODE_DIR} + DOWNLOAD_NO_PROGRESS 1 PREFIX ${EIGEN_SOURCE_DIR} - DOWNLOAD_NAME "eigen" + DOWNLOAD_NAME "eigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip" UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/cmake/external/xbyak.cmake b/cmake/external/xbyak.cmake index 1d61154c0d45dea795902d6544deb796693db263..5166b494c489e25c970c7dbfe72fa1404302009f 100644 --- a/cmake/external/xbyak.cmake +++ b/cmake/external/xbyak.cmake @@ -20,6 +20,7 @@ endif() include(ExternalProject) +SET(XBYAK_SOURCECODE_DIR ${CMAKE_SOURCE_DIR}/third-party/xbyak) set(XBYAK_PROJECT extern_xbyak) set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak) set(XBYAK_INSTALL_ROOT ${THIRD_PARTY_PATH}/install/xbyak) @@ -38,8 +39,11 @@ ExternalProject_Add( ${XBYAK_PROJECT} ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS "" - GIT_REPOSITORY "https://github.com/herumi/xbyak.git" GIT_TAG "v5.661" # Jul 26th + URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2Fxbyak-5.66.zip + DOWNLOAD_DIR ${XBYAK_SOURCECODE_DIR} + DOWNLOAD_NAME "xbyak-5.66.zip" + DOWNLOAD_NO_PROGRESS 1 PREFIX ${XBYAK_PREFIX_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${XBYAK_INSTALL_ROOT} diff --git a/cmake/external/xxhash.cmake b/cmake/external/xxhash.cmake index 23b1e02108642df561948a6faa3152effb7ca932..fdc20351e8bcdf5fe8e95db3516f4c6f607611db 100644 --- a/cmake/external/xxhash.cmake +++ b/cmake/external/xxhash.cmake @@ -1,5 +1,6 @@ INCLUDE(ExternalProject) +SET(XXHASH_SOURCECODE_DIR ${CMAKE_SOURCE_DIR}/third-party/xxhash) set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash) set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash) set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include") @@ -18,10 +19,12 @@ if(WIN32) ExternalProject_Add( extern_xxhash ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" GIT_TAG "v0.6.5" + URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2FxxHash-0.6.5.zip + DOWNLOAD_DIR ${XXHASH_SOURCECODE_DIR} + DOWNLOAD_NAME "xxHash-0.6.5.zip" + DOWNLOAD_NO_PROGRESS 1 PREFIX ${XXHASH_SOURCE_DIR} - DOWNLOAD_NAME "xxhash" UPDATE_COMMAND "" BUILD_IN_SOURCE 1 PATCH_COMMAND @@ -41,10 +44,12 @@ else() ExternalProject_Add( extern_xxhash ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" GIT_TAG "v0.6.5" + URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2FxxHash-0.6.5.zip + DOWNLOAD_DIR 
${XXHASH_SOURCECODE_DIR} + DOWNLOAD_NO_PROGRESS 1 PREFIX ${XXHASH_SOURCE_DIR} - DOWNLOAD_NAME "xxhash" + DOWNLOAD_NAME "xxHash-0.6.5.zip" UPDATE_COMMAND "" CONFIGURE_COMMAND "" BUILD_IN_SOURCE 1 diff --git a/cmake/lite.cmake b/cmake/lite.cmake index a095eea6d1cce9ba09ee631a50b8029e769f6d37..fd40fa437b52ff33089b55c6cfb7df6604a0530d 100644 --- a/cmake/lite.cmake +++ b/cmake/lite.cmake @@ -22,7 +22,7 @@ endfunction() function (lite_deps TARGET) set(options "") set(oneValueArgs "") - set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS NPU_DEPS XPU_DEPS ARGS) + set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS BM_DEPS NPU_DEPS XPU_DEPS CV_DEPS ARGS) cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) set(deps ${lite_deps_DEPS}) @@ -44,7 +44,7 @@ function (lite_deps TARGET) set(deps ${deps} ${var}) endforeach(var) if(LITE_WITH_CV) - foreach(var ${lite_cv_deps}) + foreach(var ${lite_deps_CV_DEPS}) set(deps ${deps} ${var}) endforeach(var) endif() @@ -94,6 +94,12 @@ function (lite_deps TARGET) endforeach(var) endif() + if (LITE_WITH_BM) + foreach(var ${lite_deps_BM_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + set(${TARGET} ${deps} PARENT_SCOPE) endfunction() @@ -115,10 +121,11 @@ file(WRITE ${offline_lib_registry_file} "") # clean # LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK # HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK # EXCLUDE_COMPILE_DEPS: TARGET will not be included in lite_compile_deps if this is not None +# CV_DEPS: LITE_WITH_CV function(lite_cc_library TARGET) set(options SHARED shared STATIC static MODULE module) set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS XPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS BM_DEPS NPU_DEPS XPU_DEPS CV_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -128,10 +135,12 @@ function(lite_cc_library TARGET) X86_DEPS ${args_X86_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} - NPU_DEPS ${args_NPU_DEPS} - XPU_DEPS ${args_XPU_DEPS} + BM_DEPS ${args_BM_DEPS} ARM_DEPS ${args_ARM_DEPS} + CV_DEPS ${args_CV_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} @@ -161,8 +170,8 @@ function(lite_cc_binary TARGET) set(options " -g ") endif() set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS - LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS BM_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS + LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS CV_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) set(deps "") @@ -173,9 +182,13 @@ function(lite_cc_binary TARGET) CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} + BM_DEPS ${args_BM_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} + CV_DEPS ${CV_DEPS} ) cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps}) target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers) @@ -205,8 +218,8 @@ function(lite_cc_test 
TARGET) endif() set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS - LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS BM_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS + LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS CV_DEPS ARGS COMPILE_LEVEL # (basic|extra) ) @@ -225,9 +238,13 @@ function(lite_cc_test TARGET) CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} + BM_DEPS ${args_BM_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} + CV_DEPS ${args_CV_DEPS} ) _lite_cc_test(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ARGS ${args_ARGS}) # strip binary target to reduce size @@ -252,6 +269,7 @@ set(cuda_kernels CACHE INTERNAL "cuda kernels") set(fpga_kernels CACHE INTERNAL "fpga kernels") set(npu_kernels CACHE INTERNAL "npu kernels") set(xpu_kernels CACHE INTERNAL "xpu kernels") +set(bm_kernels CACHE INTERNAL "bm kernels") set(opencl_kernels CACHE INTERNAL "opencl kernels") set(host_kernels CACHE INTERNAL "host kernels") @@ -262,12 +280,12 @@ if(LITE_BUILD_TAILOR) file(STRINGS ${tailored_kernels_list_path} tailored_kernels_list) endif() # add a kernel for some specific device -# device: one of (Host, ARM, X86, NPU, FPGA, OPENCL, CUDA) +# device: one of (Host, ARM, X86, NPU, FPGA, OPENCL, CUDA, BM) # level: one of (basic, extra) function(add_kernel TARGET device level) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS BM_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -333,6 +351,12 @@ function(add_kernel TARGET device level) endif() set(fpga_kernels "${fpga_kernels};${TARGET}" CACHE INTERNAL "") endif() + if ("${device}" STREQUAL "BM") + if (NOT LITE_WITH_BM) + return() + endif() + set(bm_kernels "${bm_kernels};${TARGET}" CACHE INTERNAL "") + endif() if ("${device}" STREQUAL "OPENCL") if (NOT LITE_WITH_OPENCL) return() @@ -360,11 +384,13 @@ function(add_kernel TARGET device level) lite_cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS} X86_DEPS ${args_X86_DEPS} - XPU_DEPS ${args_XPU_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} + BM_DEPS ${args_BM_DEPS} PROFILE_DEPS ${args_PROFILE_DEPS} LIGHT_DEPS ${args_LIGHT_DEPS} HVY_DEPS ${args_HVY_DEPS} @@ -383,7 +409,7 @@ endif() function(add_operator TARGET level) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS BM_DEPS NPU_DEPS XPU_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS) cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -409,11 +435,13 @@ function(add_operator TARGET level) lite_cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${args_DEPS} X86_DEPS ${args_X86_DEPS} - XPU_DEPS ${args_XPU_DEPS} CUDA_DEPS ${args_CUDA_DEPS} CL_DEPS ${args_CL_DEPS} ARM_DEPS ${args_ARM_DEPS} FPGA_DEPS ${args_FPGA_DEPS} + NPU_DEPS ${args_NPU_DEPS} + XPU_DEPS ${args_XPU_DEPS} + BM_DEPS ${args_BM_DEPS} PROFILE_DEPS 
${args_PROFILE_DEPS}
              LIGHT_DEPS ${args_LIGHT_DEPS}
              HVY_DEPS ${args_HVY_DEPS}
diff --git a/cmake/xpu.cmake b/cmake/xpu.cmake
index 8d99343c3041351102820cb20890031fa3f5807e..2112f6b658f5f89b20d63c957cd0b979299c350b 100644
--- a/cmake/xpu.cmake
+++ b/cmake/xpu.cmake
@@ -99,7 +99,7 @@ else()
     set_property(TARGET xpu_sdk_llvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_LLVM_FILE})
 endif()
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_GLOG=1")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_GLOG=1 -D_GLIBCXX_USE_CXX11_ABI=0")
 set(xpu_runtime_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu runtime libs")
 set(xpu_builder_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu builder libs")
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..298ea9e213e8c4c11f0431077510d4e325733c65
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,19 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..66f9b291ba3b459a8d3a327f7a71d9bd2f7031e0
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1 @@
+Please refer to the [PaddleLite documentation style guide](http://agroup.baidu.com/paddle-infer/md/article/2561104).
diff --git a/docs/advanced_user_guides/add_layout.md b/docs/advanced_user_guides/add_layout.md
new file mode 100644
index 0000000000000000000000000000000000000000..11e504f93c2b1bcaefaa06c0a5f51aea0995884e
--- /dev/null
+++ b/docs/advanced_user_guides/add_layout.md
@@ -0,0 +1,184 @@
+# How to Add a Layout
+
+In Paddle-Lite, a `Place` combines Target, Layout, and Precision information and is used to register and select the concrete kernels of a model. Taking the new layouts `ImageDefault`, `ImageFolder`, and `ImageNW` as examples, this guide explains how to add a new layout to `Place`.
+
+Searching the code under `lite/core/` and `lite/api` with the keyword `NHWC` shows that the new layout has to be added to each of the following files:
+
+1. lite/api/paddle_place.h
+2. lite/api/paddle_place.cc
+3. lite/api/python/pybind/pybind.cc
+4. lite/core/op_registry.h
+5. lite/core/op_registry.cc
+
+## 1. lite/api/paddle_place.h
+
+Add the new layouts to `enum class DataLayoutType`. Note that the values of existing layouts must not change; new layouts simply take the next increasing values:
+
+```cpp
+enum class DataLayoutType : int {
+  kUnk = 0,
+  kNCHW = 1,
+  kNHWC = 3,
+  kImageDefault = 4,  // for opencl image2d
+  kImageFolder = 5,   // for opencl image2d
+  kImageNW = 6,       // for opencl image2d
+  kAny = 2,           // any data layout
+  NUM = 7,            // number of fields.
+};
+```
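+
+Before wiring the layout through the remaining files, it helps to see where it will eventually be used. The sketch below shows how a kernel registration could declare the new layout; the kernel class and alias are hypothetical placeholders, and the `REGISTER_LITE_KERNEL`/`LiteType::GetTensorTy` pattern follows the one used in the operator-adding guide in this same directory:
+
+```cpp
+// Hypothetical registration of an OpenCL layout-conversion kernel that
+// produces tensors in the new kImageDefault layout. The class name and the
+// alias ("nchw_to_image") are illustrative only.
+REGISTER_LITE_KERNEL(layout,
+                     kOpenCL,
+                     kAny,
+                     kImageDefault,  // the newly added layout
+                     paddle::lite::kernels::opencl::LayoutComputeNchwToImage,
+                     nchw_to_image)
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                               PRECISION(kAny),
+                                               DATALAYOUT(kNCHW))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                              PRECISION(kAny),
+                                              DATALAYOUT(kImageDefault))})
+    .Finalize();
+```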
+## 2. lite/api/paddle_place.cc
+
+Three places in this file need changes. Note that the string names added in `DataLayoutToStr` must follow the order of the enum values in `lite/api/paddle_place.h`:
+
+```cpp
+// Change 1 in this file
+const std::string& DataLayoutToStr(DataLayoutType layout) {
+  static const std::string datalayout2string[] = {
+      "unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
+  auto x = static_cast<int>(layout);
+  CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
+  return datalayout2string[x];
+}
+
+// Change 2 in this file
+const std::string& DataLayoutRepr(DataLayoutType layout) {
+  static const std::string datalayout2string[] = {"kUnk",
+                                                  "kNCHW",
+                                                  "kAny",
+                                                  "kNHWC",
+                                                  "kImageDefault",
+                                                  "kImageFolder",
+                                                  "kImageNW"};
+  auto x = static_cast<int>(layout);
+  CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
+  return datalayout2string[x];
+}
+
+// Change 3 in this file
+std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
+  static const std::set<DataLayoutType> valid_set({DATALAYOUT(kNCHW),
+                                                   DATALAYOUT(kAny),
+                                                   DATALAYOUT(kNHWC),
+                                                   DATALAYOUT(kImageDefault),
+                                                   DATALAYOUT(kImageFolder),
+                                                   DATALAYOUT(kImageNW)});
+  if (layout == DATALAYOUT(kAny)) {
+    return valid_set;
+  }
+  return std::set<DataLayoutType>({layout});
+}
+```
+
+## 3. lite/api/python/pybind/pybind.cc
+
+Expose the new layouts in the Python bindings as well:
+
+```cpp
+  // DataLayoutType
+  py::enum_<DataLayoutType>(*m, "DataLayoutType")
+      .value("NCHW", DataLayoutType::kNCHW)
+      .value("NHWC", DataLayoutType::kNHWC)
+      .value("ImageDefault", DataLayoutType::kImageDefault)
+      .value("ImageFolder", DataLayoutType::kImageFolder)
+      .value("ImageNW", DataLayoutType::kImageNW)
+      .value("Any", DataLayoutType::kAny);
+```
+
+## 4. lite/core/op_registry.h
+
+In `class KernelRegistry final`, find `using any_kernel_registor_t =` and add the following entries (one per target/precision/layout combination; the combinations match the `INIT_FOR` entries added in the next section):
+
+```cpp
+// Add the following inside `using any_kernel_registor_t =`
+// in class KernelRegistry final:
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)> *,          //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)> *,          //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault)> *,  //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageFolder)> *,   //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageNW)> *,       //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)> *, //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageFolder)> *,  //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageNW)> *,      //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)> *,   //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageFolder)> *,    //
+    KernelRegistryForTarget<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageNW)> *,        //
+```
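+
+With the registry aware of the new layouts, a `Place` can now carry them. A minimal sketch, assuming a `CxxConfig config` as in the demo code of the pass guide:
+
+```cpp
+// Prefer OpenCL kernels with the new image layout, falling back to ARM CPU.
+std::vector<Place> valid_places{
+    Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault)},
+    Place{TARGET(kARM), PRECISION(kFloat)},
+};
+config.set_valid_places(valid_places);
+```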
+## 5. lite/core/op_registry.cc
+
+Two places in this file need changes:
+
+```cpp
+// Change 1 in this file
+#define CREATE_KERNEL1(target__, precision__)                                \
+  switch (layout) {                                                          \
+    case DATALAYOUT(kNCHW):                                                  \
+      return Create<TARGET(target__),                                        \
+                    PRECISION(precision__),                                  \
+                    DATALAYOUT(kNCHW)>(op_type);                             \
+    case DATALAYOUT(kAny):                                                   \
+      return Create<TARGET(target__),                                        \
+                    PRECISION(precision__),                                  \
+                    DATALAYOUT(kAny)>(op_type);                              \
+    case DATALAYOUT(kNHWC):                                                  \
+      return Create<TARGET(target__),                                        \
+                    PRECISION(precision__),                                  \
+                    DATALAYOUT(kNHWC)>(op_type);                             \
+    case DATALAYOUT(kImageDefault):                                          \
+      return Create<TARGET(target__),                                        \
+                    PRECISION(precision__),                                  \
+                    DATALAYOUT(kImageDefault)>(op_type);                     \
+    case DATALAYOUT(kImageFolder):                                           \
+      return Create<TARGET(target__),                                        \
+                    PRECISION(precision__),                                  \
+                    DATALAYOUT(kImageFolder)>(op_type);                      \
+    case DATALAYOUT(kImageNW):                                               \
+      return Create<TARGET(target__),                                        \
+                    PRECISION(precision__),                                  \
+                    DATALAYOUT(kImageNW)>(op_type);                          \
+    default:                                                                 \
+      LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
+  }
+
+// Change 2 in this file
+// Find the following constructor in the file:
+KernelRegistry::KernelRegistry()
+    : registries_(static_cast<int>(TARGET(NUM)) *
+                  static_cast<int>(PRECISION(NUM)) *
+                  static_cast<int>(DATALAYOUT(NUM)))
+
+// and add the following entries for the new layouts inside it:
+  INIT_FOR(kOpenCL, kFP16, kNCHW);
+  INIT_FOR(kOpenCL, kFP16, kNHWC);
+  INIT_FOR(kOpenCL, kFP16, kImageDefault);
+  INIT_FOR(kOpenCL, kFP16, kImageFolder);
+  INIT_FOR(kOpenCL, kFP16, kImageNW);
+  INIT_FOR(kOpenCL, kFloat, kImageDefault);
+  INIT_FOR(kOpenCL, kFloat, kImageFolder);
+  INIT_FOR(kOpenCL, kFloat, kImageNW);
+  INIT_FOR(kOpenCL, kAny, kImageDefault);
+  INIT_FOR(kOpenCL, kAny, kImageFolder);
+  INIT_FOR(kOpenCL, kAny, kImageNW);
+```
diff --git a/docs/advanced_user_guides/add_new_pass.md b/docs/advanced_user_guides/add_new_pass.md
new file mode 100644
index 0000000000000000000000000000000000000000..93b27cd038642c702cd213adffcc378dc852a1b3
--- /dev/null
+++ b/docs/advanced_user_guides/add_new_pass.md
@@ -0,0 +1,437 @@
+
+# How to Add a Pass
+
+This document introduces the `Pass` mechanism in `Lite` from three angles: **what a pass is**, **the pass implementation and interfaces**, and **the general registration workflow**. Finally, it uses `fc_fuse_pass` as an example to explain what a fusion pass does and how to register one.
+
+## Preliminaries: What Is a Pass?
+
+**After CxxPredictor loads a model, it optimizes the model before running inference. This model optimization process is implemented with passes.**
+The call chain is as follows:
+![图片](https://user-images.githubusercontent.com/45189361/69638690-20d21880-1096-11ea-8169-1d2c7e1a1609.png)
+
+ - `CreatePredictor(CxxConfig)` calls `Predictor->Build(CxxConfig)`.
+ - Building the CxxPredictor (Build) takes two steps:
+     - `Predictor->LoadModel()` loads the model file into a program.
+     - `Predictor->optimizer_.Run()` optimizes the original graph structure of the program.
+ - The graph optimization is carried out by calling `Pass->Apply(const std::unique_ptr<SSAGraph>& graph)`.
+
+**Each kind of pass defines one optimization**, such as kernel selection, op fusion, redundant-op removal, subgraph creation, memory optimization, type inference, or type conversion.
+
+## Pass Implementation and Interfaces: the Pass Base Class, PassManager, and Pass Registration
+
+### 1. The Pass base class: `paddle::lite::mir::Pass`
+```c++
+class Pass {
+ public:
+  // The kind of a pass. Passes fall into three kinds by what they do:
+  enum class Kind {
+    // 1. Passes that modify the graph topology of the model.
+    kProgramWise = 0,
+    // 2. Passes that keep the graph structure and only modify statement state.
+    kStmtWise,
+    // 3. Passes that do not modify the IR; used to collect information and for visualization.
+    kDebug,
+  };
+
+  // The main interface: Apply defines what the pass does when it runs.
+  virtual void Apply(const std::unique_ptr<SSAGraph>& graph) = 0;
+
+  bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
+  bool is_stmt_pass() const { return kind_ == Kind::kStmtWise; }
+
+  virtual ~Pass() = default;
+
+ private:
+  const Kind kind_;   // the kind of the pass
+  std::string name_;  // the name of the pass
+  std::set<TargetType> bound_targets_;  // the hardware targets the pass runs on; the optimizer filters passes by whether the current target matches.
+  std::unordered_map<std::string, std::set<lite_api::PrecisionType>> bound_kernels_;  // the kernels bound to the pass
+};
+
+// Different kinds of passes derived from it:
+class ProgramPass : public Pass {
+ public:
+  ProgramPass() : Pass(Kind::kProgramWise) {}
+};
+
+class StmtPass : public Pass {
+ public:
+  StmtPass() : Pass(Kind::kStmtWise) {}
+};
+
+class DebugPass : public Pass {
+ public:
+  DebugPass() : Pass(Kind::kDebug) {}
+};
+```
+**Location**: `lite/core/mir/pass.h`
+**Key members**:
+ - `const Kind kind_`: the pass kind. There are three basic kinds: `ProgramPass` modifies the graph structure, `StmtPass` modifies statement state, and `DebugPass` collects information and controls visualization during debugging.
+ - `std::string name_`: the name of the pass.
+ - `std::set<TargetType> bound_targets_`: the hardware targets the pass runs on; during `optimizer.Run()` the matching passes are selected automatically according to the target platform.
+ - `std::unordered_map<std::string, std::set<lite_api::PrecisionType>> bound_kernels_`: the kernels bound to the pass.
+**Key interface**:
+ - `Pass::Apply(const std::unique_ptr<SSAGraph>& graph)`: the concrete operation the pass performs; this is the interface a new pass must implement. Its input is a pointer to an `SSAGraph`, a topological representation of the model structure.
+
+### 2. Pass management: `paddle::lite::mir::PassManager`
+
+```c++
+class PassManager {
+ public:
+  // A static singleton that stores the registered passes and drives graph optimization.
+  static PassManager& Global() {
+    static PassManager x;
+    return x;
+  }
+
+  // Run all passes on the graph.
+  void Run(const std::unique_ptr<SSAGraph>& graph) {
+    for (auto& pass : passes_) {
+      LOG(INFO) << "Running MIR pass " << pass->name();
+      pass->Apply(graph);
+    }
+  }
+
+ private:
+  std::list<std::unique_ptr<mir::Pass>> passes_;  // stores all passes
+  std::map<std::string, mir::Pass*> pass_map_;    // stores the name -> pass mapping
+};
+```
+**Location**: `lite/core/mir/pass_manager.h`
+**Key members**:
+ - `std::list<std::unique_ptr<mir::Pass>> passes_;`: a list holding all registered passes.
+ - `std::map<std::string, mir::Pass*> pass_map_;`: a map holding all "pass name -> pass object" pairs, used to look up passes by name.
+
+**Key interfaces**:
+ - `static PassManager& Global()`: returns the global static PassManager, which stores all registered passes.
+ - `bool AddNewPass(const std::string& name, Pass* pass)`: adds a new pass to the PassManager.
+
+### 3. Pass registration: `paddle::lite::mir::PassRegistry`
+**Location**: `lite/core/mir/pass_registry.h`
+**Key interface**:
+ - `REGISTER_MIR_PASS(name__, class__)`: a macro for registering a pass. Registration boils down to `PassManager::Global().AddNewPass(name__, class__)`, which adds the newly registered pass to the global `PassManager`.
+
+## General Pass Registration Workflow and Usage
+
+### 1. Registration workflow
+To register a new pass, inherit from the `Pass` base class under `lite/core/mir` (or one of its subdirectories), implement the `Pass::Apply` interface, and register the pass with the `PassManager` via the `REGISTER_MIR_PASS(name__, class__)` macro.
+
+**Taking a new `example_pass` as the example**, the steps are:
+(1) Create `example_pass.cc` and `example_pass.h` under `lite/core/mir`.
+(2) In `example_pass.h`, define your pass class by inheriting from one of the base classes (ProgramPass, StmtPass, or DebugPass):
+```c++
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+class ExamplePass : public ProgramPass {
+  void Apply(const std::unique_ptr<SSAGraph> &graph) override {}
+  ...
+};
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+```
+(3) In `example_pass.cc`, implement the `ExamplePass::Apply()` interface and register `ExamplePass`:
+```c++
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/example_pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+void ExamplePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  ...
+}
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+REGISTER_MIR_PASS(example_pass, paddle::lite::mir::ExamplePass)
+    .BindTargets({TARGET(kARM)});  // the hardware targets the pass runs on
+    // .BindKernel("conv2d");      // the kernel bound to the pass
+```
+
+(4) Modify `lite/core/mir/CMakeLists.txt` to compile `example_pass.cc` into the `mir_passes` library:
+
+```cmake
+lite_cc_library(mir_passes
+  SRCS
+      example_pass.cc  # the new pass file
+      ...
+      memory_optimize_pass.cc
+  DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
+```
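+
+To make the workflow concrete, here is a minimal sketch of what `ExamplePass::Apply` could do: walk the graph and log every statement's op type. It assumes the `SSAGraph::StmtTopologicalOrder()` and `Node::AsStmt()` APIs; treat it as an illustration, not part of the registration contract:
+
+```c++
+void ExamplePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // Visit all statement (op) nodes in topological order and log them.
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    if (!node->IsStmt()) continue;
+    LOG(INFO) << "ExamplePass visits op: " << node->AsStmt().op_type();
+  }
+}
+```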
+### 2. Usage workflow
+
+Registering a pass with the PassManager does not make it take effect by itself; the pass only runs during model optimization if it is also added in `optimizer->Run()`.
+(1) Reference the pass in `paddle_use_passes.h`:
+
+```c++
+#include "paddle_lite_factory_helper.h"  // NOLINT
+    ...
+USE_MIR_PASS(example_pass);  // reference the new pass
+```
+(2) To invoke the pass when optimizing a model, add it manually in `optimizer->Run()`. Modify `lite/core/optimizer.h` and add `example_pass` to `Optimizer::Run()`:
+
+```c++
+class Optimizer {
+ public:
+  void Run(...) {
+    ...
+    if (passes.empty()) {
+      RunPasses(std::vector<std::string>{
+          {"example_pass",  // add the newly registered pass here
+           ...
+          }});
+    }
+    ...
+  }
+```
+(3) Only a CxxPredictor optimizes the model with the passes after loading it:
+```c++
+    ...
+#include "paddle_use_passes.h"  // pull in the passes used to optimize the model
+void RunModel() {
+  // 1. Create a CxxConfig
+  CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
+
+  // 2. Create a CxxPredictor. This loads the model and optimizes it with the passes.
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<CxxConfig>(config);
+}
+```
+
+## Defining and Registering a Fusion Pass
+
+A `Fusion Pass` is a common kind of graph-optimization pass. It fuses several consecutive ops into a single equivalent op, reducing data exchange and simplifying the graph. At run time the pass invokes a `Fuser` that automatically finds and replaces the specified graph structure, so registering a fusion pass also requires implementing the corresponding fuser class.
+
+The following uses `fc_fuse_pass` to explain in detail what a fusion pass does and how to register one.
+
+### What `fc_fuse_pass` does
+It fuses an adjacent `mul` op and `elementwise_add` op into a single `fc` op:
+```c++
+mul(X) = X * W
+elementwise_add(mul(X)) = X * W + Bias
+//----------> after fusion
+FC(X) = X * W + Bias
+```
+
+The effect of running the pass:
+![图片](https://user-images.githubusercontent.com/45189361/69639193-12383100-1097-11ea-9063-21f030414080.png)
+The original parameters of `mul` and `elementwise_add` map onto the parameters of `fc`:
+![图片](https://user-images.githubusercontent.com/45189361/69638836-74446680-1096-11ea-9cdc-a961fa995dfe.png)
+
+### How to register `fc_fuse_pass`
+#### 1. Create the FcFuser
+(1) Create `fc_fuser.cc` and `fc_fuser.h` under `lite/core/mir/fusion`.
+(2) In `fc_fuser.h`, define your fuser class by inheriting from `FuseBase`:
+
+```c++
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class FcFuser : public FuseBase {
+ public:
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+```
+**Key interfaces**:
+ - `FuseBase::BuildPattern`: describes the graph structure (pattern) to replace; at run time the fuser automatically finds and replaces this pattern.
+ - `FuseBase::GenOpDesc`: creates the equivalent fused op.
+ - `FuseBase::InsertNewNode`: replaces the original graph structure (pattern) with the fused op.
+
+For `FcFuser`: BuildPattern describes the `mul + elementwise_add` pattern, GenOpDesc creates the fc op, and InsertNewNode replaces the `mul + elementwise_add` pattern in the model with the newly created fc op.
+
+(3) Implement the `BuildPattern()`, `GenOpDesc()`, and `InsertNewNode()` interfaces in `fc_fuser.cc`. The FcFuser implementations of the three interfaces are shown below:
+
+```c++
+// 1. BuildPattern describes the graph structure to replace.
+// FcFuser::BuildPattern() describes the mul + elementwise_add structure.
+void FcFuser::BuildPattern() {
+  // (1) Describe the ops and variables with OpNode and VarNode.
+  // the mul op
+  auto* mul = OpNode("mul", "mul");
+  // inputs and output of the mul op
+  auto* x = VarNode("x")->assert_is_op_input("mul", "X");
+  auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
+  auto* mul_out = VarNode("mul_out");
+
+  // the elementwise_add op
+  auto* add = OpNode("add", "elementwise_add");
+  // input of elementwise_add
+  auto* b = VarNode("b")->assert_is_persistable_var();
+  // output of elementwise_add (the final output)
+  auto* Out = VarNode("Out");
+
+  // (2) Describe the topology (how mul and elementwise_add are connected before the fusion).
+  std::vector<PMNode*> mul_inputs{W, x};
+  std::vector<PMNode*> add_inputs{mul_out, b};
+  mul_inputs >> *mul >> *mul_out;
+  add_inputs >> *add >> *Out;
+
+  // (3) Mark the nodes that will be removed from the new topology,
+  //     i.e. the fused ops and the intermediate variables between them.
+  mul_out->AsIntermediate();
+  mul->AsIntermediate();
+  add->AsIntermediate();
+}
+
+// 2. GenOpDesc creates the equivalent fused op.
+// FcFuser::GenOpDesc() creates the fc op.
+cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
+  // (1) Take the OpDesc of the first op node and clear its input/output info.
+  cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
+  op_desc.mutable_inputs()->clear();
+  op_desc.mutable_outputs()->clear();
+  // (2) Modify the OpDesc: set the op type to "fc" (the op type of the FC op).
+  op_desc.SetType("fc");
+  // (3) Set the Input, Output, and Attribute entries of the OpDesc,
+  //     connecting them to the VarNodes created in BuildPattern().
+  op_desc.SetInput("Input", {matched.at("x")->arg()->name});
+  op_desc.SetInput("W", {matched.at("W")->arg()->name});
+  op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
+  op_desc.SetAttr(
+      "in_num_col_dims",
+      matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
+  return op_desc;
+}
+
+// 3. InsertNewNode replaces the original pattern in the model graph with the fused op.
+// FcFuser::InsertNewNode() replaces "mul + elementwise_add" with the fc op.
+void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
+  // (1) Create the parameters (OpDesc) of the fc op.
+  auto op_desc = GenOpDesc(matched);
+  // Create an fc op.
+  auto fc_op = LiteOpRegistry::Global().Create("fc");
+
+  // Find the scope and valid_places (supported device types) of the original topology.
+  auto mul = matched.at("mul")->stmt()->op();
+  auto* scope = mul->scope();
+  auto& valid_places = mul->valid_places();
+
+  // (2) Give the fc op the same scope and valid_places as before the fusion,
+  //     and create its node in the graph.
+  fc_op->Attach(op_desc, scope);
+  auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
+
+  // (3) Link the fc node to its input and output var nodes.
+  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("x"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("b"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
+}
+```
+
+#### 2. Register fc_fuse_pass
+
+(1) Create `fc_fuse_pass.cc` and `fc_fuse_pass.h` under `lite/core/mir/fusion`.
+(2) In `fc_fuse_pass.h`, define `FcFusePass` by inheriting from `ProgramPass`:
+
+```c++
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class FcFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+```
+(3) In `fc_fuse_pass.cc`, implement the `FcFusePass::Apply()` interface and register `FcFusePass`:
+```c++
+#include "lite/core/mir/fusion/fc_fuse_pass.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::FcFuser fuser;
+  fuser(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
+    .BindTargets({TARGET(kAny)})  // FcFusePass can run on any hardware target
+    .BindKernel("fc");            // FcFusePass binds the fc kernel
+```
+
+(4) Modify `lite/core/mir/fusion/CMakeLists.txt` to compile `fc_fuser.cc` into the `mir_fusers` library:
+
+```cmake
+lite_cc_library(fuse_fc
+        SRCS fc_fuser.cc
+        DEPS pattern_matcher_high_api)
+
+set(mir_fusers
+    fuse_fc
+    ...
+    CACHE INTERNAL "fusers")
+```
+
+(5) Modify `lite/core/mir/CMakeLists.txt` to compile `fc_fuse_pass.cc` into the `mir_passes` library:
+```cmake
+lite_cc_library(mir_passes
+  SRCS
+      fusion/fc_fuse_pass.cc
+      ...
+  DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
+```
+
+#### 3. Use fc_fuse_pass
+
+(1) `lite/api/paddle_use_passes.h` uses the `USE_MIR_PASS` macro to reference the newly added pass:
+
+```c++
+USE_MIR_PASS(lite_fc_fuse_pass);
+```
+(2) Add the newly registered pass to the `Optimizer::Run()` function in `lite/core/optimizer.h`:
+```C++
+class Optimizer {
+ public:
+  void Run(Program&& program,
+           const std::vector<Place>& valid_places,
+           core::KernelPickFactor kernel_pick_factor,
+           const std::vector<std::string>& passes = {}) {
+    ...
+    if (passes.empty()) {
+      RunPasses(std::vector<std::string>{
+          {"lite_fc_fuse_pass",  // the newly registered pass
+           ...
+           "argument_type_display_pass"}});
+    } else {
+      RunPasses(passes);
+    }
+    exec_scope_ = program.exec_scope();
+  }
+```
+(3) With the changes above in place, when `CreatePredictor(CxxConfig)` builds a CxxPredictor, model optimization invokes `lite_fc_fuse_pass`, which scans for the `mul + elementwise_add` structure and replaces it with the equivalent fc op.
diff --git a/docs/advanced_user_guides/add_operation.md b/docs/advanced_user_guides/add_operation.md
new file mode 100644
index 0000000000000000000000000000000000000000..525832f8a9d7341c3124498084e05b160358b2ad
--- /dev/null
+++ b/docs/advanced_user_guides/add_operation.md
@@ -0,0 +1,189 @@
+# How to Add an Operator
+
+The following uses argmax as an example to explain in detail how to add a new op.
+
+## 1. Add an OpParam struct to carry the op's inputs and outputs
+
+- Here it is named `ArgmaxParam`.
+
+- Add the `ArgmaxParam` struct in `paddlelite/lite/operators/op_params.h`:
+  ```c++
+  struct ArgmaxParam {
+    lite::Tensor* X{};
+    lite::Tensor* Out{};
+    int Axis{0};
+  };
+  ```
+## 2. Add and register the Argmax op
+
+- Create argmax_op.h under the paddlelite/lite/operators/ directory. The main code is:
+  ```c++
+  class ArgmaxOpLite : public OpLite {
+   public:
+    ArgmaxOpLite() {}
+    explicit ArgmaxOpLite(const std::string &op_type) : OpLite(op_type) {}
+    bool CheckShape() const override;
+    bool InferShape() const override;
+    bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+    void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+    std::string DebugString() const override { return "argmax"; }
+   private:
+    mutable ArgmaxParam param_;
+  };
+  ```
+  `ArgmaxOpLite` inherits from `OpLite`; its members include an `ArgmaxParam` struct, and the interfaces to implement are `CheckShape()`, `InferShape()`, `AttachImpl()`, `AttachKernel()`, and `DebugString()`. `AttachKernel()` and `DebugString()` are simple and implemented inline here.
+
+- Create argmax_op.cc under `paddlelite/lite/operators/`, implementing `CheckShape()`, `InferShape()`, and `AttachImpl()`. `CheckShape()` checks that the inputs are valid, `InferShape()` infers the output dimensions from the inputs, and `AttachImpl()` binds the op's inputs and outputs. Then register argmax in argmax_op.cc. The core code:
+  ```c++
+  bool ArgmaxOpLite::CheckShape() const {
+    CHECK_OR_FALSE(param_.X);
+    CHECK_OR_FALSE(param_.Out);
+    CHECK_OR_FALSE(param_.Axis < (param_.X)->dims().size());
+    return true;
+  }
+
+  bool ArgmaxOpLite::InferShape() const {
+    auto x_dims = param_.X->dims();
+    int x_rank = x_dims.size();
+    int axis = param_.Axis;
+    if (axis < 0) axis += x_rank;
+
+    std::vector<int64_t> out_dims;
+    for (int64_t i = 0; i < axis; i++) {
+      out_dims.push_back(x_dims[i]);
+    }
+    for (int64_t i = axis + 1; i < x_rank; i++) {
+      out_dims.push_back(x_dims[i]);
+    }
+
+    // Set output dims
+    param_.Out->Resize(lite::DDim(out_dims));
+    return true;
+  }
+
+  bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+    auto x = op_desc.Input("X").front();
+    auto out = op_desc.Output("Out").front();
+
+    param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
+    param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+    param_.Axis = op_desc.GetAttr<int>("Axis");
+
+    return true;
+  }
+  REGISTER_LITE_OP(argmax, paddle::lite::operators::ArgmaxOpLite);
+  ```
+- In paddlelite/lite/operators/CMakeLists.txt, add ```add_operator(argmax_op basic SRCS argmax_op.cc DEPS ${op_DEPS})```
+
+## 3. Add and bind the Argmax kernel
+
+The following uses the ARM implementation of argmax as an example.
+- Create argmax_compute.h under the paddlelite/lite/kernels/arm/ directory, declaring class ArgmaxCompute, which inherits from KernelLite<TARGET(kARM), PRECISION(kFloat)>:
+  ```c++
+  class ArgmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+   public:
+    using param_t = operators::ArgmaxParam;
+    void Run() override;
+    virtual ~ArgmaxCompute() = default;
+  };
+  ```
+- Create argmax_compute.cc under paddlelite/lite/kernels/arm/, mainly implementing the Run function. `Run()` calls `argmax_func()` from paddlelite/lite/backends/arm/math/argmax.h to compute the output from the input. Finally, in argmax_compute.cc, bind argmax's inputs and outputs (every input parameter that is a tensor must be bound):
+  ```c++
+  void ArgmaxCompute::Run() {
+    auto& param = Param<operators::ArgmaxParam>();
+    lite::Tensor* input = param.X;
+    lite::Tensor* output = param.Out;
+    int axis = param.Axis;
+    lite::arm::math::argmax_func(input, axis, output);
+    return;
+  }
+
+  REGISTER_LITE_KERNEL(
+      argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def)
+      .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+      .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+      .Finalize();
+  ```
+
+- In paddlelite/lite/kernels/arm/CMakeLists.txt, add
+  ```cmake
+  add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
+  ```
+
+## 4. Add the Argmax implementation
+
+- Create argmax.h under the paddlelite/lite/backends/arm/math/ directory, declaring `argmax_func()`:
+  ```c++
+  void argmax_func(const lite::Tensor* input, const int axis, lite::Tensor* output);
+  ```
+- Create argmax.cc under the same directory, implementing `argmax_func()`:
+  ```c++
+  void argmax_func(const lite::Tensor *input,
+                   const int axis,
+                   lite::Tensor *output) {
+    auto input_ddim = input->dims();
+    auto output_ddim = output->dims();
+
+    const int size = input_ddim[axis];
+    const int in_channel = input_ddim.count(axis, input_ddim.size());
+    const int out_channel = output_ddim.count(axis, output_ddim.size());
+    const int in_stride = input_ddim.count(axis + 1, input_ddim.size());
+    const int out_stride = input_ddim.count(0, axis);
+
+    for (int n = 0; n < out_stride; n++) {
+      for (int k = 0; k < in_stride; k++) {
+        const float *in_ptr = input->data<float>() + n * in_channel + k;
+        std::vector<std::pair<float, int>> vec;
+        vec.resize(size);
+        for (int i = 0; i < size; i++) {
+          vec[i] = std::make_pair(in_ptr[i * in_stride], i);
+        }
+        // sort
+        std::partial_sort(vec.begin(),
+                          vec.begin() + 1,
+                          vec.end(),
+                          std::greater<std::pair<float, int>>());
+
+        // out
+        float *out_ptr = output->mutable_data<float>() + n * out_channel + k;
+        *out_ptr = vec[0].second;
+      }
+    }
+  }
+  ```
+- Add argmax.cc to the ```math_arm``` library in paddlelite/lite/backends/arm/math/CMakeLists.txt, and add ```#include "lite/backends/arm/math/argmax.h"``` to paddlelite/lite/backends/arm/math/funcs.h.
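+
+The pair-plus-`std::partial_sort` idiom used by `argmax_func` is easy to check in isolation. A standalone sketch in plain C++, independent of Lite:
+
+```c++
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <utility>
+#include <vector>
+
+int main() {
+  std::vector<float> in{0.3f, 2.5f, 1.7f, 0.9f};
+  // Pair each value with its index, as argmax_func does for every slice.
+  std::vector<std::pair<float, int>> vec;
+  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
+    vec.emplace_back(in[i], i);
+  }
+  // Move only the largest pair to the front; cheaper than a full sort.
+  std::partial_sort(vec.begin(), vec.begin() + 1, vec.end(),
+                    std::greater<std::pair<float, int>>());
+  std::cout << "argmax index: " << vec[0].second << std::endl;  // prints 1
+  return 0;
+}
+```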
+
+## 5. Add the Argmax unit test
+
+- Create argmax_compute_test.cc under the paddlelite/lite/tests/kernels directory, declaring and implementing class ArgmaxComputeTester.
+- ArgmaxComputeTester mainly contains the PrepareOpDesc, PrepareData, and RunBaseline functions. PrepareOpDesc sets the op type and the input/output parameters of the test, PrepareData initializes the input tensors, and RunBaseline computes the reference output from the inputs, which is compared against the output computed by the framework.
+- Add the test with gtest:
+  ```c++
+  TEST(Argmax, precision) {
+  #ifdef LITE_WITH_ARM
+    LOG(INFO) << "test argmax arm";
+    Place place(TARGET(kARM));
+
+    for (int axis : {0, 1, 2, 3}) {
+      for (int n : {1, 3}) {
+        for (int c : {3, 6}) {
+          for (int h : {9, 18}) {
+            for (int w : {9, 18}) {
+              std::unique_ptr<arena::TestCase> tester(
+                  new ArgmaxComputeTester(place, "def", axis, n, c, h, w));
+              arena::Arena arena(std::move(tester), place, 2e-5);
+              arena.TestPrecision();
+            }
+          }
+        }
+      }
+    }
+  #endif
+  }
+  ```
+- In paddlelite/lite/tests/kernels/CMakeLists.txt, add
+  ```cmake
+  lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  ```
+## 6. Build and run
+- In the paddlelite directory, run ```./lite/tools/ci_build.sh build_test_arm```. The script creates a phone emulator and builds and runs all unit tests (this takes quite a while). If everything runs without errors, argmax has been added successfully.
diff --git a/docs/advanced_user_guides/index.rst b/docs/advanced_user_guides/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/advanced_user_guides/model_quantization.md b/docs/advanced_user_guides/model_quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d781ba9904400c26b64aed5f5dc764ecc5b24fa
--- /dev/null
+++ b/docs/advanced_user_guides/model_quantization.md
@@ -0,0 +1,327 @@
+# Model Quantization
+
+This document explains how to load quantized models produced by PaddlePaddle with Paddle-Lite and run inference on them. Using MobileNetV1 as the example, it first covers preparing a quantized model and then deploying and running it.
+
+## Preparing a Quantized Model
+
+PaddlePaddle quantizes an FP32 model to an INT8 model with one of two methods, quantization-aware training or post-training quantization. The following sections describe how each method produces a quantized model.
+
+### Quantization-aware training
+
+Currently, quantization-aware training in the PaddlePaddle framework mainly targets convolution layers (both 2-D and depthwise convolutions) and fully connected layers, corresponding to the conv2d, depthwise_conv2d, and mul operators; see the [documentation](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/docs/tutorial.md#1-quantization-aware-training%E9%87%8F%E5%8C%96%E4%BB%8B%E7%BB%8D) for more on how it works. Paddle-Lite can run models produced by PaddlePaddle's quantization-aware training, which further speeds up model execution on mobile devices.
+
+Tip: if you are new to the PaddlePaddle framework, start with the [beginner's guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/index_cn.html) and the [user guides](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/index_cn.html).
+
+You can either download a pre-trained quantized model or train one yourself with the PaddleSlim model-compression tool.
+
+#### Download a quantized model
+
+An official [quantized MobileNetV1 model](https://paddle-inference-dist.bj.bcebos.com/int8%2Fpretrain%2Fmobilenet_v1_quant%2Ffloat.zip) is available for direct download:
+
+```bash
+wget https://paddle-inference-dist.bj.bcebos.com/int8%2Fpretrain%2Fmobilenet_v1_quant%2Ffloat.zip
+```
+
+#### Train a quantized model with the PaddleSlim model-compression tool
+
+##### Install PaddlePaddle
+
+Install PaddlePaddle following the [official instructions](https://paddlepaddle.org.cn/start) for your operating system, installation method, Python version, and CUDA version. For example:
+
+On Ubuntu 16.04.4 LTS with CUDA 9 and cuDNN 7, GPU build:
+```bash
+pip install paddlepaddle-gpu==1.6.0.post97 -i https://mirrors.aliyun.com/pypi/simple/
+```
+
+On Ubuntu 16.04.4 LTS, CPU build:
+```bash
+pip install paddlepaddle==1.6.0 -i https://mirrors.aliyun.com/pypi/simple/
+```
+
+##### Clone the repository needed for quantization-aware training
+
+Clone [PaddlePaddle/models](https://github.com/PaddlePaddle/models) and enter the models/PaddleSlim directory:
+
+```bash
+git clone https://github.com/PaddlePaddle/models.git
+cd models/PaddleSlim
+```
+
+##### Data preparation
+###### Training data
+
+Following the data-preparation tutorial in [models/PaddleCV/image_classification](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#data-preparation), download the training data and save it under PaddleSlim/data.
+
+###### Pretrained model
+
+Referring to the /models/PaddleSlim/run.sh script,
+download the MobileNetV1 pretrained model from [models/PaddleCV/image_classification](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification#supported-models-and-performances) and save it under PaddleSlim/pretrain.
+
+After these three steps, the directory structure under PaddleSlim looks like this:
+
+```bash
+.
+├── compress.py          # main script for compression tasks; defines the model-side information a task needs
+├── configs              # task configuration files: distillation, int8 quantization, filter pruning, and combined strategies
+├── data                 # training data (created by the user)
+│   └── ILSVRC2012
+├── pretrain             # pretrained model parameters; generated by running run.sh
+│   ├── MobileNetV1_pretrained
+│   ├── MobileNetV1_pretrained.tar
+│   ├── ResNet50_pretrained
+│   └── ResNet50_pretrained.tar
+├── docs                 # documentation
+├── light_nas
+├── models               # model network definitions, e.g. MobileNetV1
+├── quant_low_level_api  # low-level quantization-training APIs for fully customizing the process; for advanced users
+├── reader.py            # data-processing logic
+├── README.md
+├── run.sh               # launcher script for compression tasks
+└── utility.py           # common utility functions
+```
+
+##### The compression script
+
+`compress.py` defines all the model-related information a compression task needs. The key steps are briefly described below.
+
+###### Defining the target network
+
+The following snippet from compress.py defines the train program; here the train program contains only forward computation:
+```python
+out = model.net(input=image, class_dim=args.class_dim)
+cost = fluid.layers.cross_entropy(input=out, label=label)
+avg_cost = fluid.layers.mean(x=cost)
+acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+```
+
+Then eval_program is obtained with the clone method and used to evaluate model accuracy during compression:
+
+```python
+val_program = fluid.default_main_program().clone()
+```
+
+After defining the target network, initialize it and, if needed, load the pretrained model.
+
+###### Defining feed_list and fetch_list
+For the train program, train_feed_list specifies which variables the data taken from the train data reader is fed to, and train_fetch_list specifies the results to show in the log during training. To print accuracy information in the log during training, just add ('acc_top1', acc_top1.name) to train_fetch_list:
+```python
+train_feed_list = [('image', image.name), ('label', label.name)]
+train_fetch_list = [('loss', avg_cost.name)]
+```
+
+> Note: train_fetch_list must contain the loss entry.
+
+For the eval program, define val_feed_list and val_fetch_list in the same way:
+
+```python
+val_feed_list = [('image', image.name), ('label', label.name)]
+val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', acc_top5.name)]
+```
+
+###### The Compressor and the quantization configuration file
+`compress.py` mainly uses the Compressor class together with a YAML file to carry out quantization-aware training. The Compressor class is defined as:
+```python
+class Compressor(object):
+    def __init__(self,
+                 place,
+                 scope,
+                 train_program,
+                 train_reader=None,
+                 train_feed_list=None,
+                 train_fetch_list=None,
+                 eval_program=None,
+                 eval_reader=None,
+                 eval_feed_list=None,
+                 eval_fetch_list=None,
+                 teacher_programs=[],
+                 checkpoint_path='./checkpoints',
+                 train_optimizer=None,
+                 distiller_optimizer=None):
+```
+
+When constructing a Compressor object, pay attention to the following:
+* If the train program already contains backward operators and optimizer-update operators, the train_optimizer argument must be set to None.
+* Parameter names in eval_program must exactly match those in train_program.
+* The quantized model that is finally saved is pruned from the eval_program network. So if you want the saved model to be usable for inference, the eval program must contain all the operators needed at inference time.
+* The saved checkpoints contain a float-typed model.
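+
+For orientation, this is roughly how compress.py wires these pieces together. The sketch is based on the class signature above and the configuration file below; variable names follow the earlier snippets, and the exact call sequence may differ between PaddleSlim versions:
+
+```python
+# Sketch: assemble and run the Compressor with the objects defined above.
+com_pass = Compressor(
+    place,
+    fluid.global_scope(),
+    fluid.default_main_program(),     # forward-only train program
+    train_reader=train_reader,
+    train_feed_list=train_feed_list,
+    train_fetch_list=train_fetch_list,
+    eval_program=val_program,
+    eval_reader=val_reader,
+    eval_feed_list=val_feed_list,
+    eval_fetch_list=val_fetch_list,
+    train_optimizer=optimizer)        # set to None if the program already has backward/update ops
+com_pass.config('./configs/quantization.yaml')  # load the quantization strategy
+com_pass.run()                                  # run the compression task
+```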
+
+An example quantization configuration file, `configs/quantization.yaml`:
+
+```yaml
+version: 1.0
+strategies:
+    quantization_strategy:
+        class: 'QuantizationStrategy'
+        start_epoch: 0
+        end_epoch: 9
+        float_model_save_path: './output/float'
+        mobile_model_save_path: './output/mobile'
+        int8_model_save_path: './output/int8'
+        weight_bits: 8
+        activation_bits: 8
+        weight_quantize_type: 'abs_max'
+        activation_quantize_type: 'moving_average_abs_max'
+        save_in_nodes: ['image']
+        save_out_nodes: ['fc_0.tmp_2']
+compressor:
+    epoch: 10
+    checkpoint_path: './checkpoints_quan/'
+    strategies:
+        - quantization_strategy
+```
+The configurable parameters are:
+- **class:** the class name of the quantization strategy; currently only `QuantizationStrategy` is supported.
+- **start_epoch:** before start_epoch begins, the strategy inserts quantization and dequantization operators into train_program and eval_program; quantization-aware training starts at start_epoch.
+- **end_epoch:** after end_epoch finishes, models are saved in the formats the user specified. Note that quantization training does not stop after end_epoch; it continues until the epoch count reaches the compressor.epoch value. For example, with start_epoch=0, end_epoch=0, and compressor.epoch=2, quantization training starts at epoch 0 and ends at epoch 1, but the saved model holds the parameter state at the end of epoch 0.
+- **float_model_save_path:** the path for saving the model in float format, i.e. the parameters lie within the int8 range but their data type is float32. If set to None, no float-format model is saved; the default is None. **Note: Paddle-Lite uses the model in this directory for quantized-model inference optimization; see the [Running Quantized Model Inference with Paddle-Lite](#二使用Paddle-Lite运行量化模型推理) section of this document.**
+- **int8_model_save_path:** the path for saving the model in int8 format, i.e. the parameters lie within the int8 range and their data type is int8. If set to None, no int8-format model is saved; the default is None.
+- **mobile_model_save_path:** the path for saving a model compatible with the paddle-mobile framework. If set to None, no paddle-mobile-format model is saved; the default is None. (paddle-mobile has since been upgraded to Paddle-Lite.)
+- **weight_bits:** the bit width for quantizing weights; note that bias parameters are not quantized.
+- **activation_bits:** the bit width for quantizing activations.
+- **weight_quantize_type:** the weight quantization method; quantization-aware training currently supports `abs_max` and `channel_wise_abs_max`.
+- **activation_quantize_type:** the activation quantization method; quantization-aware training currently supports `range_abs_max` and `moving_average_abs_max`. PaddlePaddle also supports the `abs_max` method for activations, but it computes the input's quantization scale dynamically, which adds computation and slows model inference, so Lite does not support `abs_max` activation quantization.
+- **save_in_nodes:** a list of variable names. When saving the quantized model, the eval program network is pruned by forward traversal according to save_in_nodes. Defaults to the variable names given in eval_feed_list.
+- **save_out_nodes:** a list of variable names. When saving the quantized model, the eval program network is pruned by backward traversal according to save_out_nodes. Defaults to the variable names given in eval_fetch_list.
+
+> **Remarks:**
+>
+> 1) `abs_max` computes the quantization scale dynamically at every training step and at inference time. `channel_wise_abs_max` is similar, except that it computes a per-channel quantization scale for convolution weights. In other words, `abs_max` is tensor-wise quantization while `channel_wise_abs_max` is channel-wise quantization; see [here](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/quantization/training_quantization_model_format.md) for details.
+>
+> 2) `moving_average_abs_max` and `range_abs_max` compute a static quantization scale during training and use it at inference time. `moving_average_abs_max` computes the scale with a sliding-window moving average, while `range_abs_max` uses the window's absolute maximum.
+>
+> 3) **Currently, Paddle-Lite only supports running quantized models produced with `abs_max` weight quantization and `moving_average_abs_max` or `range_abs_max` activation quantization.**
+
+##### Run int8 quantization-aware training
+
+Modify run.sh: comment out the content between `# enable GC strategy` and `# for sensitivity filter pruning`, and uncomment the `# for quantization` commands (the commands to uncomment are shown below):
+
+```bash
+# for quantization
+#---------------------------
+export CUDA_VISIBLE_DEVICES=0
+python compress.py \
+--batch_size 64 \
+--model "MobileNet" \
+--pretrained_model ./pretrain/MobileNetV1_pretrained \
+--compress_config ./configs/quantization.yaml \
+--quant_only True
+```
+Finally, run `sh run.sh` to start int8 quantization-aware training.
+
+After the training finishes, if the model output paths were configured as in the `configs/quantization.yaml` above, the models/PaddleSlim/output directory contains three subdirectories, `float`, `int8`, and `mobile`:
+* float: a quantized model whose parameters lie within the int8 range but are stored as float32. Paddle-Lite uses the model files and parameters in this directory to deploy the quantized model.
+* int8: a quantized model whose parameters lie within the int8 range and are stored as int8.
+* mobile: a quantized model with the same parameter characteristics as the int8 directory that is compatible with paddle-mobile (which has since been upgraded to Paddle-Lite).
+
+### Post-training quantization
+
+
+##### 执行int8量化训练
+
+修改run.sh,即注释掉`# enable GC strategy`与`# for sensitivity filter pruning`之间的内容,并打开`#for quantization`相关的脚本命令(所需打开注释的命令如下所示)。
+
+```bash
+# for quantization
+#---------------------------
+export CUDA_VISIBLE_DEVICES=0
+python compress.py \
+--batch_size 64 \
+--model "MobileNet" \
+--pretrained_model ./pretrain/MobileNetV1_pretrained \
+--compress_config ./configs/quantization.yaml \
+--quant_only True
+```
+最后,运行`sh run.sh`命令开始int8量化训练。
+
+上述量化训练过程完成后,若按照本文中所述`configs/quantization.yaml`文件内容配置的模型输出路径,则可在models/PaddleSlim/output目录下看到`float`、`int8`和`mobile`三个目录,其中:
+* float目录: 参数范围为int8范围但参数数据类型为float32的量化模型。Paddle-Lite即使用该目录下的模型文件及参数进行量化模型的部署。
+* int8目录: 参数范围为int8范围且参数数据类型为int8的量化模型。
+* mobile目录:参数特点与int8目录相同且兼容paddle-mobile的量化模型(目前paddle-mobile已升级为Paddle-Lite)。
+
+### 训练后量化
+
+下面以MobileNetV1为例,介绍使用训练后量化方法产出量化模型。关于训练后量化的原理和详细使用方法,请参考[文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api)。
+
+> 该示例的代码放在[models/PaddleSlim/quant_low_level_api/](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api)目录下。如果需要执行该示例,首先clone下来[models](https://github.com/PaddlePaddle/models.git),并安装具有训练后量化功能的PaddlePaddle。因为目前Lite仅支持对conv2d、depthwise_conv2d和mul量化,所以修改[run_post_training_quanzation.sh](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/quant_low_level_api/run_post_training_quanzation.sh)脚本,设置is_full_quantize=False,然后执行该脚本;执行结束后,量化模型保存在`mobilenetv1_int8_model`目录下。下面介绍详细步骤。
+
+1)**准备模型和校准数据**
+
+安装PaddlePaddle的develop分支编译的whl包,准备已经训练好的FP32预测模型。
+
+准备校准数据,文件结构如下。val文件夹中有100张图片,val_list.txt文件中包含图片的label。
+```bash
+samples_100
+└──val
+└──val_list.txt
+```
+
+2)**配置校准数据生成器**
+
+MobileNetV1的输入是图片和标签,所以配置读取校准数据的sample_generator,每次返回一张图片和一个标签。详细代码在[models/PaddleSlim/reader.py](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/reader.py)。
+
+3)**调用训练后量化**
+
+调用训练后量化的核心代码如下,详细代码在[post_training_quantization.py](https://github.com/PaddlePaddle/models/blob/develop/PaddleSlim/quant_low_level_api/post_training_quantization.py)。
+```python
+place = fluid.CUDAPlace(0) if args.use_gpu == "True" else fluid.CPUPlace()
+exe = fluid.Executor(place)
+sample_generator = reader.val(data_dir=args.data_path)
+
+ptq = PostTrainingQuantization(
+    executor=exe,
+    sample_generator=sample_generator,
+    model_dir=args.model_dir,
+    model_filename=args.model_filename,
+    params_filename=args.params_filename,
+    batch_size=args.batch_size,
+    batch_nums=args.batch_nums,
+    algo=args.algo,
+    is_full_quantize=args.is_full_quantize == "True")
+quantized_program = ptq.quantize()
+ptq.save_quantized_model(args.save_model_path)
+```
+
+## 使用Paddle-Lite运行量化模型推理
+
+#### 使用模型优化工具对量化模型进行优化
+
+接下来,使用原始的量化模型生成适合在移动端直接部署的模型。
+
+参考[源码编译](../source_compile)配置编译环境,确保可以编译成功。参考[模型转化方法](../model_optimize_tool),首先编译model_optimize_tool工具,然后执行下面命令对量化训练的模型进行优化(注意,需要自行修改model_file、param_file和optimize_out)。
+```bash
+./model_optimize_tool \
+--model_file=mobilenet_v1_quant/float/model \
+--param_file=mobilenet_v1_quant/float/weights \
+--optimize_out_type=naive_buffer \
+--optimize_out=mobilenet_v1_quant_opt \
+--valid_targets=arm \
+--prefer_int8_kernel=true
+```
+
+如前所述,量化训练后,float目录下的模型参数范围为int8,但参数数据类型仍为float32类型,因此尚未起到模型参数压缩的效果。但是,经过model\_optimize\_tool工具优化后,对应的量化参数均会以int8类型重新存储,达到参数压缩的效果,且模型结构也被优化(如进行了各种operator fuse操作)。
+
+#### 在手机端准备量化模型文件
+
+使用如下命令将mobilenet_v1_quant_opt目录下的量化模型文件导入到手机端:
+
+```bash
+adb push mobilenet_v1_quant_opt /data/local/tmp
+```
+
+#### 使用mobilenetv1\_light\_api运行优化后的量化模型
+
+参考[源码编译](../source_compile)配置编译环境后,在Paddle-Lite中执行如下命令编译轻量级API的demo:
+
+```bash
+cd /Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mobile_light
+make clean && make -j
+```
+执行完上述命令后,可在`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mobile_light/`路径下看到`mobilenetv1_light_api`可执行文件。将`mobilenetv1_light_api`导入到手机端并运行量化模型推理。执行命令如下:
+
+```bash
+adb push Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mobile_light/mobilenetv1_light_api /data/local/tmp
+adb shell chmod +x /data/local/tmp/mobilenetv1_light_api
+adb shell /data/local/tmp/mobilenetv1_light_api \
+    --model_dir=/data/local/tmp/mobilenet_v1_quant_opt
+```
+**程序运行结果如下:**
+```bash
+Output dim: 1000
+Output[0]: 0.000228
+Output[100]: 0.000260
+Output[200]: 0.000250
+Output[300]: 0.000560
+Output[400]: 0.000950 +Output[500]: 0.000275 +Output[600]: 0.005143 +Output[700]: 0.002509 +Output[800]: 0.000538 +Output[900]: 0.000969 +``` +在C++中使用Paddle-Lite API的方法请猛戳[此处](../cpp_demo),用户也可参考[mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc)的代码示例。 + +### FAQ + +**问题**:Compiled with WITH_GPU, but no GPU found in runtime + +**解答**:检查本机是否支持GPU训练,如果不支持请使用CPU训练。如果在docker进行GPU训练,请使用nvidia_docker启动容器。 + +**问题**:Inufficient GPU memory to allocation. at [/paddle/paddle/fluid/platform/gpu_info.cc:262] + +**解答**:正确设置run.sh脚本中`CUDA_VISIBLE_DEVICES`,确保显卡剩余内存大于需要内存。 diff --git a/docs/advanced_user_guides/support_operation_list.md b/docs/advanced_user_guides/support_operation_list.md new file mode 100644 index 0000000000000000000000000000000000000000..7c2ceb0ff819f7f1676308a33ec88f5eab820e57 --- /dev/null +++ b/docs/advanced_user_guides/support_operation_list.md @@ -0,0 +1,392 @@ +# 支持OP列表 + +## Ops + +- affine_channel +- anchor_generator +- arg_max +- assign +- assign_value +- attention_padding_mask +- axpy +- batch_norm +- beam_search +- beam_search_decode +- bilinear_interp +- box_clip +- box_coder +- calib +- calib_once +- cast +- collect_fpn_proposals +- concat +- conditional_block +- conv2d +- conv2d_transpose +- crop +- decode_bboxes +- density_prior_box +- depthwise_conv2d +- distribute_fpn_proposals +- dropout +- elementwise_add +- elementwise_div +- elementwise_max +- elementwise_mul +- elementwise_sub +- equal +- exp +- expand +- fake_channel_wise_dequantize_max_abs +- fake_dequantize_max_abs +- fake_quantize_dequantize_moving_average_abs_max +- fake_quantize_moving_average_abs_max +- fake_quantize_range_abs_max +- fc +- feed +- fetch +- fill_constant +- fill_constant_batch_size_like +- flatten +- flatten2 +- floor +- fusion_elementwise_add_activation +- fusion_elementwise_div_activation +- fusion_elementwise_max_activation +- fusion_elementwise_mul_activation +- fusion_elementwise_sub_activation +- gather +- generate_proposals +- graph_op +- greater_equal +- greater_than +- gru +- gru_unit +- hard_sigmoid +- im2sequence +- increment +- instance_norm +- io_copy +- io_copy_once +- is_empty +- layer_norm +- layout +- layout_once +- leaky_relu +- less_equal +- less_than +- lod_reset +- log +- logical_and +- logical_not +- logical_or +- logical_xor +- lookup_table +- lookup_table_v2 +- lrn +- match_matrix_tensor +- matmul +- mean +- merge_lod_tensor +- mul +- multiclass_nms +- nearest_interp +- negative +- norm +- notequal +- pad2d +- pool2d +- power +- prelu +- prior_box +- range +- read_from_array +- reduce_max +- reduce_mean +- reduce_prod +- reduce_sum +- relu +- relu6 +- relu_clipped +- reshape +- reshape2 +- roi_align +- rsqrt +- scale +- search_aligned_mat_mul +- search_attention_padding_mask +- search_fc +- search_grnn +- search_group_padding +- search_seq_arithmetic +- search_seq_depadding +- search_seq_fc +- search_seq_softmax +- sequence_arithmetic +- sequence_concat +- sequence_expand +- sequence_expand_as +- sequence_pool +- sequence_reshape +- sequence_reverse +- sequence_softmax +- sequence_topk_avg_pooling +- shape +- shuffle_channel +- sigmoid +- slice +- softmax +- softsign +- split +- split_lod_tensor +- sqrt +- square +- squeeze +- squeeze2 +- stack +- swish +- tanh +- top_k +- transpose +- transpose2 +- uniform_random +- unsqueeze +- unsqueeze2 +- var_conv_2d +- while +- write_to_array +- yolo_box + +## Kernels + +### Host kernels + +- feed +- fetch +- flatten +- flatten2 +- 
multiclass_nms +- reshape +- reshape2 + +### ARM kernels + +- affine_channel +- anchor_generator +- arg_max +- assign +- assign_value +- axpy +- batch_norm +- beam_search +- beam_search_decode +- bilinear_interp +- box_clip +- box_coder +- cast +- collect_fpn_proposals +- concat +- conditional_block +- conv2d +- conv2d_transpose +- crop +- decode_bboxes +- density_prior_box +- depthwise_conv2d +- distribute_fpn_proposals +- dropout +- elementwise_add +- elementwise_div +- elementwise_max +- elementwise_mul +- elementwise_sub +- equal +- exp +- expand +- fc +- fill_constant +- fill_constant_batch_size_like +- floor +- fusion_elementwise_add_activation +- fusion_elementwise_div_activation +- fusion_elementwise_max_activation +- fusion_elementwise_mul_activation +- fusion_elementwise_sub_activation +- gather +- generate_proposals +- greater_equal +- greater_than +- gru +- gru_unit +- hard_sigmoid +- im2sequence +- increment +- instance_norm +- is_empty +- layer_norm +- layout +- layout_once +- leaky_relu +- less_equal +- less_than +- lod_reset +- log +- logical_and +- logical_not +- logical_or +- logical_xor +- lookup_table +- lookup_table_v2 +- lrn +- matmul +- merge_lod_tensor +- mul +- nearest_interp +- negative +- norm +- not_equal +- pad2d +- pool2d +- power +- prelu +- prior_box +- range +- read_from_array +- reduce_max +- reduce_mean +- reduce_prod +- relu +- relu6 +- relu_clipped +- roi_align +- rsqrt +- scale +- sequence_expand +- sequence_pool +- sequence_softmax +- shape +- shuffle_channel +- sigmoid +- slice +- softmax +- split +- split_lod_tensor +- squeeze +- squeeze2 +- stack +- swish +- tanh +- top_k +- transpose +- transpose2 +- unsqueeze +- unsqueeze2 +- while +- write_to_array +- yolo_box + + +### X86 kernels +- batch_norm +- cast +- concat +- conv2d +- depthwise_conv2d +- dropout +- elementwise_add +- elementwise_sub +- fc +- fill_constant_batch_size_like +- gather +- gelu +- gru +- layer_norm +- match_matrix_tensor +- matmul +- mul +- pool2d +- reduce_sum +- relu +- reshape +- reshape2 +- scale +- search_aligned_mat_mul +- search_attention_padding_mask +- search_fc +- search_grnn +- search_group_padding +- search_seq_arithmetic +- search_seq_depadding +- search_seq_fc +- search_seq_softmax +- sequence_arithmetic +- sequence_concat +- sequence_expand_as +- sequence_pool +- sequence_reverse +- sequence_topk_avg_pooling +- shape +- slice +- softmax +- softsign +- square +- squeeze +- squeeze2 +- stack +- tanh +- transpose +- transpose2 +- var_conv_2d + +### CUDA kernels +- attention_padding_mask +- bilinear_interp +- calib +- concat +- conv +- dropout +- elementwise_add +- fusion_elementwise_add_activation +- fusion_elementwise_mul_activation +- elementwise_mul +- feed +- io_copy +- layout +- layout_once +- leaky_relu +- lookup_table +- match_matrix_tensor +- mul +- nearest_interp +- pool2d +- relu +- scale +- search_aligned_mat_mul +- search_fc +- search_grnn +- search_group_padding +- search_seq_depadding +- search_seq_fc +- sequence_arithmetic +- sequence_concat +- sequence_pool +- sequence_reverse +- sequence_topk_avg_pooling +- softmax +- transpose +- var_conv_2d +- yolo_box + +### OpenCL kernels +- conv2d +- depthwise_conv2d +- elementwise_add +- fc +- fusion_elementwise_add_activation +- layout +- layout_once +- io_copy +- io_copy_once +- mul +- pool2d +- relu diff --git a/docs/advanced_user_guides/x86.md b/docs/advanced_user_guides/x86.md new file mode 100644 index 0000000000000000000000000000000000000000..7cb08683440312b0349662699b05e99df0cb6df1 --- /dev/null +++ 
b/docs/advanced_user_guides/x86.md
@@ -0,0 +1,104 @@
+# 使用X86预测库
+
+Paddle-Lite 支持在Docker或Linux环境编译x86预测库。环境搭建参考[环境准备](../installation/source_compile)。
+
+(注意:非docker Linux环境需要是Ubuntu16.04)
+
+## 编译
+
+1、 下载代码
+```bash
+git clone https://github.com/PaddlePaddle/Paddle-Lite.git
+# 需要切换到 release/v2.0.0 之后的版本
+git checkout <release-version-tag>
+```
+
+2、 源码编译
+
+```bash
+cd Paddle-Lite
+./lite/tools/build.sh x86
+```
+
+## 编译结果说明
+
+x86编译结果位于 `build.lite.x86/inference_lite_lib`
+**具体内容**说明:
+
+1、 `bin`文件夹:可执行工具文件 `test_model_bin`
+
+2、 `cxx`文件夹:包含c++的库文件与相应的头文件
+
+- `include` : 头文件
+- `lib` : 库文件
+  - 打包的静态库文件:
+    - `libpaddle_api_full_bundled.a` :包含 full_api 和 light_api 功能的静态库
+    - `libpaddle_api_light_bundled.a` :只包含 light_api 功能的静态库
+  - 打包的动态库文件:
+    - `libpaddle_full_api_shared.so` :包含 full_api 和 light_api 功能的动态库
+    - `libpaddle_light_api_shared.so`:只包含 light_api 功能的动态库
+
+3、 `third_party` 文件夹:第三方库文件
+
+## x86预测API使用示例
+
+```c++
+#include <gflags/gflags.h>
+#include <iostream>
+#include <memory>
+#include <vector>
+#include "paddle_api.h"          // NOLINT
+#include "paddle_use_kernels.h"  // NOLINT
+#include "paddle_use_ops.h"      // NOLINT
+#include "paddle_use_passes.h"   // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+DEFINE_string(model_dir, "", "Model dir path.");
+DEFINE_string(optimized_model_dir, "", "Optimized model dir.");
+DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels");
+
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+void RunModel() {
+  // 1. Set CxxConfig
+  CxxConfig config;
+  config.set_model_file(FLAGS_model_dir + "model");
+  config.set_param_file(FLAGS_model_dir + "params");
+  config.set_valid_places({
+      lite_api::Place{TARGET(kX86), PRECISION(kFloat)}
+  });
+
+  // 2. Create PaddlePredictor by CxxConfig
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<CxxConfig>(config);
+
+  // 3. Prepare input data
+  std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+  input_tensor->Resize(shape_t({1, 3, 224, 224}));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+    data[i] = 1;
+  }
+
+  // 4. Run predictor
+  predictor->Run();
+
+  // 5. Get output
+  std::unique_ptr<const Tensor> output_tensor(
+      std::move(predictor->GetOutput(0)));
+  std::cout << "Output dim: " << output_tensor->shape()[1] << std::endl;
+  for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
+    std::cout << "Output[" << i << "]: " << output_tensor->data<float>()[i]
+              << std::endl;
+  }
+}
+
+int main(int argc, char** argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  RunModel();
+  return 0;
+}
+```
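+
+如果编译时开启了Python绑定(类似后文CUDA部分使用的`--build_python=ON`选项),也可以用Python接口完成同样的预测流程。下面是一段示意代码:其中`lite_core`的库路径以及`get_input`、`set_float_data`等接口名均为假设,具体请以实际编译产出的Python库为准:
+
+```python
+# -*- coding: utf-8 -*-
+# 示意代码:x86 Python接口预测(库路径与部分接口名为假设,仅供参考)
+import sys
+sys.path.append('build.lite.x86/inference_lite_lib/python/lib')  # 假设的库路径
+from lite_core import *
+
+config = CxxConfig()
+config.set_model_file('./mobilenet_v1/model')   # combined模型的模型文件
+config.set_param_file('./mobilenet_v1/params')  # combined模型的参数文件
+config.set_valid_places([Place(TargetType.X86, PrecisionType.FP32)])
+
+predictor = create_paddle_predictor(config)
+
+input_tensor = predictor.get_input(0)                 # 假设的接口名
+input_tensor.resize([1, 3, 224, 224])                 # 假设的接口名
+input_tensor.set_float_data([1.0] * (3 * 224 * 224))  # 假设的接口名
+
+predictor.run()
+
+output_tensor = predictor.get_output(0)               # 假设的接口名
+print('Output shape:', output_tensor.shape())
+```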
diff --git a/docs/api_reference/cxx_api_doc.md b/docs/api_reference/cxx_api_doc.md
new file mode 100644
index 0000000000000000000000000000000000000000..38385a4267d5727d9c5c7d985d3457dd011e203c
--- /dev/null
+++ b/docs/api_reference/cxx_api_doc.md
@@ -0,0 +1,874 @@
+
+# C++ API文档
+
+## CreatePaddlePredictor
+
+```c++
+template <typename ConfigT>
+std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT&);
+```
+
+`CreatePaddlePredictor`用来根据`MobileConfig`或`CxxConfig`构建预测器。
+
+示例:
+
+```c++
+// 设置MobileConfig
+MobileConfig config;
+config.set_model_dir(FLAGS_model_dir);
+
+// 根据MobileConfig创建PaddlePredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
+```
+
+参数:
+
+- `config(MobileConfig)` - 用于构建Predictor的配置信息。
+
+返回:`PaddlePredictor`指针
+
+返回类型:`std::shared_ptr<PaddlePredictor>`
+
+## CxxConfig
+
+```c++
+class CxxConfig;
+```
+
+`CxxConfig`用来配置构建CxxPredictor的配置信息,如protobuf格式的模型地址、能耗模式、工作线程数、place信息等等。
+
+示例:
+
+```c++
+CxxConfig config;
+// 设置模型目录,加载非combined模型时使用
+config.set_model_dir(FLAGS_model_dir);
+// 设置工作线程数
+config.set_threads(4);
+// 设置能耗模式
+config.set_power_mode(LITE_POWER_NO_BIND);
+// 设置valid places
+std::vector<Place> valid_places{Place{TARGET(kARM), PRECISION(kFloat)}};
+config.set_valid_places(valid_places);
+
+// 根据CxxConfig创建CxxPredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
+```
+
+### `set_model_dir(model_dir)`
+
+设置模型文件夹路径,当需要从磁盘加载非combined模型时使用。
+
+参数:
+
+- `model_dir(std::string)` - 模型文件夹路径
+
+返回:`None`
+
+返回类型:`void`
+
+### `model_dir()`
+
+返回设置的模型文件夹路径。
+
+参数:
+
+- `None`
+
+返回:模型文件夹路径
+
+返回类型:`std::string`
+
+### `set_model_file(model_file)`
+
+设置模型文件路径,加载combined形式模型时使用。
+
+参数:
+
+- `model_file(std::string)` - 模型文件路径
+
+返回类型:`void`
+
+### `model_file()`
+
+获取设置的模型文件路径,加载combined形式模型时使用。
+
+参数:
+
+- `None`
+
+返回:模型文件路径
+
+返回类型:`std::string`
+
+### `set_param_file(param_file)`
+
+设置模型参数文件路径,加载combined形式模型时使用。
+
+参数:
+
+- `param_file(std::string)` - 模型参数文件路径
+
+返回类型:`void`
+
+### `param_file()`
+
+获取设置的模型参数文件路径,加载combined形式模型时使用。
+
+参数:
+
+- `None`
+
+返回:模型参数文件路径
+
+返回类型:`std::string`
+
+### `set_valid_places(valid_places)`
+
+设置可用的places列表。
+
+参数:
+
+- `valid_places(std::vector<Place>)` - 可用place列表。
+
+返回类型:`void`
+
+示例:
+
+```c++
+CxxConfig config;
+// 设置模型目录,加载非combined模型时使用
+config.set_model_dir(FLAGS_model_dir);
+// 设置valid places
+// 注意,valid_places列表中Place的排序表明了用户对Place的偏好程度,如用户想优先使用ARM上Int8精度的
+// kernel,则应把Place{TARGET(kARM), PRECISION(kInt8)}置于valid_places列表的首位。
+std::vector<Place> valid_places{Place{TARGET(kARM), PRECISION(kInt8)},
+                                Place{TARGET(kARM), PRECISION(kFloat)}};
+config.set_valid_places(valid_places);
+
+// 根据CxxConfig创建CxxPredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
+```
+
+### `set_power_mode(mode)`
+
+设置CPU能耗模式。若不设置,则默认使用`LITE_POWER_HIGH`。
+
+*注意:只在开启`OpenMP`时生效,否则系统自动调度。此函数只在使用`LITE_WITH_ARM`编译选项下生效。*
+
+参数:
+
+- `mode(PowerMode)` - CPU能耗模式
+
+返回:`None`
+
+返回类型:`void`
+
+### `power_mode()`
+
+获取设置的CPU能耗模式。
+
+*注意:此函数只在使用`LITE_WITH_ARM`编译选项下生效。*
+
+参数:
+
+- `None`
+
+返回:设置的CPU能耗模式
+
+返回类型:`PowerMode`
+
+### `set_threads(threads)`
+
+设置工作线程数。若不设置,则默认使用单线程。
+
+*注意:只在开启`OpenMP`的模式下生效,否则只使用单线程。此函数只在使用`LITE_WITH_ARM`编译选项下生效。*
+
+参数:
+
+- `threads(int)` - 工作线程数
+
+返回:`None`
+
+返回类型:`void`
+
+### `threads()`
+
+获取设置的工作线程数。
+
+*注意:此函数只在使用`LITE_WITH_ARM`编译选项下生效。*
+
+参数:
+
+- `None`
+
+返回:工作线程数
+
+返回类型:`int`
+
+### `set_x86_math_library_num_threads(threads)`
+
+设置CPU Math库线程数,CPU核心数支持情况下可加速预测。默认为1,并且仅在x86下有效。
+
+参数:
+
+- `threads(int)` - CPU Math库线程数。
+
+返回:`None`
+
+返回类型:`void`
+
+### `x86_math_library_num_threads()`
+
+返回CPU Math库线程数,CPU核心数支持情况下可加速预测。仅在x86下有效。
+
+参数:
+
+- `None`
+
+返回:CPU Math库线程数。
+
+返回类型:`int`
+
+## MobileConfig
+
+```c++
+class MobileConfig;
+```
+
+`MobileConfig`用来配置构建轻量级PaddlePredictor的配置信息,如NaiveBuffer格式的模型地址、模型的内存地址(从内存加载模型时使用)、能耗模式、工作线程数等等。
+
+*注意:输入的模型需要使用[Model Optimize Tool](../model_optimize_tool)转化为NaiveBuffer格式的优化模型。*
+
+示例:
+
+```c++
+MobileConfig config;
+// 设置NaiveBuffer格式模型目录,从文件加载模型时使用
+config.set_model_dir(FLAGS_model_dir);
+// 设置工作线程数
+config.set_threads(4);
+// 设置能耗模式
+config.set_power_mode(LITE_POWER_HIGH);
+
+// 根据MobileConfig创建PaddlePredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
+```
+
+### `set_model_from_file(model_dir)`
+
+设置模型文件,当需要从磁盘加载模型时使用。
+
+参数:
+
+- `model_dir(std::string)` - 模型文件路径
+
+返回:`None`
+
+返回类型:`void`
+
+### `set_model_dir(model_dir)`
+
+**注意**:Lite模型格式在release/v2.3.0之后修改,本接口为加载老格式模型的接口,将在release/v3.0.0废弃。建议替换为`set_model_from_file`接口。
+
+设置模型文件夹路径,当需要从磁盘加载模型时使用。
+
+参数:
+
+- `model_dir(std::string)` - 模型文件夹路径
+
+返回:`None`
+
+返回类型:`void`
+
+### `model_dir()`
+
+返回设置的模型文件夹路径。
+
+参数:
+
+- `None`
+
+返回:模型文件夹路径
+
+返回类型:`std::string`
+
+### `set_model_from_buffer(model_buffer)`
+
+设置模型的内存数据,当需要从内存加载模型时使用。
+
+参数:
+
+- `model_buffer(std::string)` - 内存中的模型数据
+
+返回:`None`
+
+返回类型:`void`
+
+### `set_model_buffer(model_buffer, model_buffer_size, param_buffer, param_buffer_size)`
+
+**注意**:Lite模型格式在release/v2.3.0之后修改,本接口为加载老格式模型的接口,将在release/v3.0.0废弃。建议替换为`set_model_from_buffer`接口。
+
+设置模型、参数的内存地址,当需要从内存加载模型时使用。
+
+示例:
+
+```c++
+// 读取模型文件到内存
+std::string model_buffer = ReadFile(FLAGS_model_path);
+std::string params_buffer = lite::ReadFile(FLAGS_params_path);
+
+// 设置MobileConfig
+lite_api::MobileConfig config;
+config.set_model_buffer(model_buffer.c_str(), model_buffer.size(),
+                        params_buffer.c_str(), params_buffer.size());
+
+// 根据MobileConfig创建PaddlePredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
+```
+
+参数:
+
+- `model_buffer(const char*)` - 内存中模型结构数据。
+- `model_buffer_size(size_t)` - 内存中模型结构数据的大小。
+- `param_buffer(const char*)` - 内存中模型参数数据。
+- `param_buffer_size(size_t)` - 内存中模型参数数据的大小。
+
+返回:`None`
+
+返回类型:`void`
+
+### `model_from_memory()`
+
+是否从内存中加载模型,当使用`set_model_buffer`接口时返回`true`。
+
+参数:
+
+- `None`
+
+返回:是否从内存加载模型
+
+返回类型:`bool`
+
+### `model_buffer()`
+
+获取内存中模型结构数据。
+
+参数:
+
+- `None`
+
+返回:内存中模型结构数据
+
+返回类型:`const std::string&`
+
+### `param_buffer()`
+
+获取内存中模型参数数据。
+
+参数:
+
+- `None`
+
+返回:内存中模型参数数据
+
+返回类型:`const std::string&`
+
+### `set_power_mode(mode)`
+
+设置CPU能耗模式。若不设置,则默认使用`LITE_POWER_HIGH`。
+
+*注意:只在开启`OpenMP`时生效,否则系统自动调度。*
+
+参数:
+
+- `mode(PowerMode)` - CPU能耗模式
+
+返回:`None`
+
+返回类型:`void`
+
+### `power_mode()`
+
+获取设置的CPU能耗模式。
+
+参数:
+
+- `None`
+
+返回:设置的CPU能耗模式
+
+返回类型:`PowerMode`
+
+### `set_threads(threads)`
+
+设置工作线程数。若不设置,则默认使用单线程。
+
+*注意:只在开启`OpenMP`的模式下生效,否则只使用单线程。*
+
+参数:
+
+- `threads(int)` - 工作线程数
+
+返回:`None`
+
+返回类型:`void`
+
+### `threads()`
+
+获取设置的工作线程数。
+
+参数:
+
+- `None`
+
+返回:工作线程数
+
+返回类型:`int`
+
+## PaddlePredictor
+
+```c++
+class PaddlePredictor;
+```
+
+`PaddlePredictor`是Paddle-Lite的预测器,由`CreatePaddlePredictor`根据`MobileConfig`进行创建。用户可以根据PaddlePredictor提供的接口设置输入数据、执行模型预测、获取输出以及获得当前使用lib的版本信息等。
+
+示例:
+
+```c++
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+// 设置MobileConfig
+MobileConfig config;
+config.set_model_dir(FLAGS_model_dir);
+
+// 根据MobileConfig创建PaddlePredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
+
+// 获得模型的输入和输出名称
+std::vector<std::string> input_names = predictor->GetInputNames();
+for (int i = 0; i < input_names.size(); i++) {
+  printf("Input name[%d]: %s\n", i, input_names[i].c_str());
+}
+std::vector<std::string> output_names = predictor->GetOutputNames();
+for (int i = 0; i < output_names.size(); i++) {
+  printf("Output name[%d]: %s\n", i, output_names[i].c_str());
+}
+
+// 准备输入数据
+// (1)根据index获取输入Tensor
+std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+// (2)根据名称获取输入Tensor
+// std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInputByName(input_names[0])));
+input_tensor->Resize({1, 3, 224, 224});
+auto* data = input_tensor->mutable_data<float>();
+for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+  data[i] = 1;
+}
+
+// 执行预测
+predictor->Run();
+
+// 获取输出
+// (1)根据index获取输出Tensor
+std::unique_ptr<const Tensor> output_tensor(std::move(predictor->GetOutput(0)));
+// (2)根据名称获取输出Tensor
+// std::unique_ptr<const Tensor> output_tensor(std::move(predictor->GetTensor(output_names[0])));
+printf("Output dim: %d\n", static_cast<int>(output_tensor->shape()[1]));
+for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
+  printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+}
+```
+
+### `GetInput(index)`
+
+获取输入Tensor指针,用来设置模型的输入数据。
+
+参数:
+
+- `index(int)` - 输入Tensor的索引
+
+返回:第`index`个输入`Tensor`的指针
+
+返回类型:`std::unique_ptr<Tensor>`
+
+### `GetOutput(index)`
+
+获取输出Tensor的指针,用来获取模型的输出结果。
+
+参数:
+
+- `index(int)` - 输出Tensor的索引
+
+返回:第`index`个输出`Tensor`的指针
+
+返回类型:`std::unique_ptr<const Tensor>`
+
+### `GetInputNames()`
+
+获取所有输入Tensor的名称。
+
+参数:
+
+- `None`
+
+返回:所有输入Tensor的名称
+
+返回类型:`std::vector<std::string>`
+
+### `GetOutputNames()`
+
+获取所有输出Tensor的名称。
+
+参数:
+
+- `None`
+
+返回:所有输出Tensor的名称
+
+返回类型:`std::vector<std::string>`
+
+### `GetInputByName(name)`
+
+根据名称获取输入Tensor的指针,用来设置模型的输入数据。
+
+参数:
+
+- `name(const std::string)` - 输入Tensor的名称
+
+返回:输入`Tensor`的指针
+
+返回类型:`std::unique_ptr<Tensor>`
+
+### `GetTensor(name)`
+
+根据名称获取输出Tensor的指针。
+
+**注意**:`GetTensor`接口是为开发者设计的调试接口,可以输出[转化](../model_optimize_tool)后模型中的任一节点。如果出现`GetTensor(InputName)`返回值为空`Tensor`,可能原因是以该`InputName`命名的Tensor在模型转化的**子图融合**过程被融合替换了。
+
+参数:
+
+- `name(const std::string)` - Tensor的名称
+
+返回:指向`const Tensor`的指针
+
+返回类型:`std::unique_ptr<const Tensor>`
+
+### `Run()`
+
+执行模型预测,需要在***设置输入数据后***调用。
+
+参数:
+
+- `None`
+
+返回:`None`
+
+返回类型:`void`
+
+### `GetVersion()`
+
+用于获取当前lib使用的代码版本。若代码有相应tag则返回tag信息,如`v2.0-beta`;否则返回代码的`branch(commitid)`,如`develop(7e44619)`。
+
+参数:
+
+- `None`
+
+返回:当前lib使用的代码版本信息
+
+返回类型:`std::string`
+
+## TargetType
+
+```c++
+enum class TargetType;
+```
+`TargetType`为目标设备硬件类型,用户可以根据应用场景选择硬件平台类型。
+
+枚举型变量`TargetType`的所有可能取值包括:
+
+`{X86, CUDA, ARM, OpenCL, FPGA, NPU}`
+
+## PrecisionType
+
+```c++
+enum class PrecisionType;
+```
+`PrecisionType`为模型中Tensor的数据精度,默认值为FP32(float32)。
+
+枚举型变量`PrecisionType`的所有可能取值包括:
+
+`{FP32, INT8, INT32, INT64}`
+
+## DataLayoutType
+
+```c++
+enum class DataLayoutType;
+```
+`DataLayoutType`为Tensor的数据格式,默认值为NCHW(number, channel, height, width)。
+
+枚举型变量`DataLayoutType`的所有可能取值包括:
+
+`{NCHW, NHWC}`
+
+## Place
+
+```c++
+class Place {
+  TargetType target;
+  PrecisionType precision{FP32};
+  DataLayoutType layout{NCHW};
+};
+```
+`Place`是`TargetType`、`PrecisionType`和`DataLayoutType`的集合,说明运行时的设备类型、数据精度和数据格式。
+
+示例:
+```c++
+Place{TargetType(ARM), PrecisionType(FP32), DataLayoutType(NCHW)}
+```
+
+## PowerMode
+
+```c++
+enum PowerMode;
+```
+
+`PowerMode`为ARM CPU能耗模式,用户可以根据应用场景设置能耗模式获得最优的能效比。
+
+示例:
+
+```c++
+MobileConfig config;
+// 设置NaiveBuffer格式模型目录
+config.set_model_dir(FLAGS_model_dir);
+// 设置能耗模式
+config.set_power_mode(LITE_POWER_HIGH);
+
+// 根据MobileConfig创建PaddlePredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
+```
+
+PowerMode详细说明如下:
+
+| 选项 | 说明 |
+| :------------------: | ------------------------------------------------------------ |
+| LITE_POWER_HIGH | 绑定大核运行模式。如果ARM CPU支持big.LITTLE,则优先使用并绑定Big cluster。如果设置的线程数大于大核数量,则会将线程数自动缩放到大核数量。如果系统不存在大核或者在一些手机的低电量情况下会出现绑核失败,如果失败则进入不绑核模式。 |
+| LITE_POWER_LOW | 绑定小核运行模式。如果ARM CPU支持big.LITTLE,则优先使用并绑定Little cluster。如果设置的线程数大于小核数量,则会将线程数自动缩放到小核数量。如果找不到小核,则自动进入不绑核模式。 |
+| LITE_POWER_FULL | 大小核混用模式。线程数可以大于大核数量。当线程数大于核心数量时,则会自动将线程数缩放到核心数量。 |
+| LITE_POWER_NO_BIND | 不绑核运行模式(推荐)。系统根据负载自动调度任务到空闲的CPU核心上。 |
+| LITE_POWER_RAND_HIGH | 轮流绑定大核模式。如果Big cluster有多个核心,则每预测10次后切换绑定到下一个核心。 |
+| LITE_POWER_RAND_LOW | 轮流绑定小核模式。如果Little cluster有多个核心,则每预测10次后切换绑定到下一个核心。 |
+
+## Tensor
+
+```c++
+class Tensor;
+```
+
+Tensor是Paddle-Lite的数据组织形式,用于对底层数据进行封装并提供接口对数据进行操作,包括设置Shape、数据、LoD信息等。
+
+*注意:用户应使用`PaddlePredictor`的`GetInput`和`GetOutput`接口获取输入/输出的`Tensor`。*
+
+示例:
+
+```c++
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+// 设置MobileConfig
+MobileConfig config;
+config.set_model_dir(FLAGS_model_dir);
+
+// 根据MobileConfig创建PaddlePredictor
+std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<MobileConfig>(config);
+
+// 准备输入数据,获取输入Tensor
+std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+// 设置输入Tensor维度信息
+input_tensor->Resize({1, 3, 224, 224});
+// 设置输入数据
+auto* data = input_tensor->mutable_data<float>();
+for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+  data[i] = 1;
+}
+
+// 执行预测
+predictor->Run();
+
+// 获取输出Tensor
+std::unique_ptr<const Tensor> output_tensor(std::move(predictor->GetOutput(0)));
+// 获取输出Tensor维度
+printf("Output dim: %d\n", static_cast<int>(output_tensor->shape()[1]));
+// 获取输出Tensor数据
+for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
+  printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+}
+```
+
+### `Resize(shape)`
+
+设置Tensor的维度信息。
+
+参数:
+
+- `shape(std::vector<int64_t>)` - 维度信息
+
+返回:`None`
+
+返回类型:`void`
+
+### `shape()`
+
+获取Tensor的维度信息。
+
+参数:
+
+- `None`
+
+返回:Tensor的维度信息
+
+返回类型:`std::vector<int64_t>`
+
+### `data()`
+
+```c++
+template <typename T>
+const T* data() const;
+```
+
+获取Tensor的底层数据的常量指针,根据传入的模板参数类型获取相应类型的数据,用于读取Tensor数据。
+
+示例:
+
+```c++
+std::unique_ptr<const Tensor> output_tensor(std::move(predictor->GetOutput(0)));
+// 如果模型中输出为float类型
+output_tensor->data<float>()
+```
+
+参数:
+
+- `None`
+
+返回:`Tensor`底层数据常量指针
+
+返回类型:`const T*`
+
+### `mutable_data()`
+
+```c++
+template <typename T>
+T* mutable_data() const;
+```
+
+获取Tensor的底层数据的指针,根据传入的模板参数类型获取相应类型的数据,用于设置Tensor数据。
+
+示例:
+
+```c++
+std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+// 如果模型中输入为float类型
+auto* data = input_tensor->mutable_data<float>();
+// 设置Tensor数据
+for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+  data[i] = 1;
+}
+```
+
+参数:
+
+- `None`
+
+返回:`Tensor`底层数据指针
+
+返回类型:`T*`
+
+### `SetLoD(lod)`
+
+设置Tensor的LoD信息。
+
+参数:
+
+- `lod(std::vector<std::vector<uint64_t>>)` - Tensor的LoD信息
+
+返回:`None`
+
+返回类型:`void`
+
+### `lod()`
+
+获取Tensor的LoD信息。
+
+参数:
+
+- `None`
+
+返回:`Tensor`的LoD信息
+
+返回类型:`std::vector<std::vector<uint64_t>>`
diff --git a/docs/api_reference/index.rst b/docs/api_reference/index.rst
new file mode 100644
index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/benchmark/benchmark.md b/docs/benchmark/benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..efb0805fddc0bd62a2b21a130018edaa9213e0cf --- /dev/null +++ b/docs/benchmark/benchmark.md @@ -0,0 +1,150 @@ +# Benchmark 数据 + +可以参考[benchmark_tools](benchmark_tools),推荐**一键benchmark**。 + +## 测试环境 + +* 测试模型 + * fp32模型 + * mobilenet_v1 + * mobilenet_v2 + * squeezenet_v1.1 + * mnasnet + * shufflenet_v2 + + * int8模型 + * mobilenet_v1 + * mobilenet_v2 + * resnet50 + +* 测试机器(android ndk ndk-r17c) + * 骁龙855 + * xiaomi mi9, snapdragon 855 + * 4xA76(1@2.84GHz + 3@2.4GHz) + 4xA55@1.78GHz + + + * 骁龙845 + * xiaomi mi8, 845 + * 2.8GHz(大四核),1.7GHz(小四核) + + * 骁龙835 + * xiaomi mix2, snapdragon 835 + * 2.45GHz(大四核),1.9GHz(小四核) + + * 骁龙625 + * oppo R9s, snapdragon625 + * A53 x 8, big core@2.0GHz + + * 骁龙653 + * 360 N5, snapdragon 653 + * 4 x A73@2.0GHz + 4 x A53@1.4GHz + + * 麒麟970 + * HUAWEI Mate10 + +* 测试说明 + * branch: release/2.0.0 + * warmup=10, repeats=30,统计平均时间,单位是ms + * 当线程数为1时,```DeviceInfo::Global().SetRunMode```设置LITE_POWER_HIGH,否者设置LITE_POWER_NO_BIND + * 模型的输入图像的维度是{1, 3, 224, 224},输入图像的每一位数值是1 + +## 测试数据 + + +### fp32模型测试数据 + +#### paddlepaddle model + + +骁龙855|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 +mobilenet_v1 |32.19 |18.81 |10.90 |30.92 |18.31 |10.15 +mobilenet_v2 |22.91 |13.75 |8.64 |21.15 |12.79 |7.84 +shufflenet_v2 |4.67 |3.37 |2.65 |4.43 |3.15 |2.66 +squeezenet_v1.1 |25.10 |15.93 |9.68 |23.28 |14.61 |8.71 +mnasnet |21.84 |13.14 |7.96 |19.61 |11.88 |7.55 + + + +骁龙835|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 +mobilenet_v1 |94.13 |52.17 |30.68 |88.28 |47.58 |26.64 +mobilenet_v2 |61.24 |34.64 |22.36 |56.66 |32.19 |19.63 +shufflenet_v2 |10.87 |6.92 |5.12 |10.41 |6.76 |4.97 +squeezenet_v1.1 |73.61 |42.25 |24.44 |64.87 |38.43 |23.06 +mnasnet |58.22 |33.43 |20.44 |53.43 |30.20 |18.09 + + +麒麟980|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 +mobilenet_v1 |55.11 |28.24 |13.27 |34.24 |17.74 |12.41 +mobilenet_v2 |37.03 |19.80 |51.94 |23.64 |12.98 |9.38 +shufflenet_v2 |7.26 |4.94 |15.06 |5.32 |3.33 |2.82 +squeezenet_v1.1 |42.73 |23.66 |57.39 |26.03 |14.53 |13.66 +mnasnet |36.87 |20.15 |46.04 |21.85 |12.06 |8.68 + +麒麟970|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 +mobilenet_v1 |97.80 |52.64 |34.46 |94.51 |49.36 |28.43 +mobilenet_v2 |66.55 |38.52 |23.19 |62.89 |34.93 |21.53 +shufflenet_v2 |13.78 |8.11 |5.93 |11.95 |7.90 |5.91 +squeezenet_v1.1 |77.64 |43.67 |25.72 |69.91 |40.66 |24.62 +mnasnet |61.86 |34.62 |22.68 |59.61 |32.79 |19.56 + +#### caffe model + +骁龙855|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |32.42 |18.68 |10.86 |30.92 |18.35 |10.07 | +mobilenet_v2 |29.53 |17.76 |10.89 |27.19 |16.53 |9.75 | +shufflenet_v2 |4.61 |3.29 |2.61 |4.36 |3.11 |2.51 | + + +骁龙835|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |92.52 |52.34 |30.37 |88.31 |49.75 |27.29 | +mobilenet_v2 |79.50 |45.67 |28.79 |76.13 |44.01 |26.13 | +shufflenet_v2 |10.94 |7.08 |5.16 |10.64 |6.83 |5.01 | + + +麒麟980|armv7 | armv7 | armv7 |armv8 | 
armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |55.36 |28.18 |13.31 |34.42 |17.93 |12.52 | +mobilenet_v2 |49.17 |26.10 |65.49 |30.50 |16.66 |11.72 | +shufflenet_v2 |8.45 |5.00 |15.65 |4.58 |3.14 |2.83 | + + +麒麟970|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |97.85 |53.38 |33.85 |94.29 |49.42 |28.29 | +mobilenet_v2 |87.40 |50.25 |31.85 |85.55 |48.11 |28.24 | +shufflenet_v2 |12.16 |8.39 |6.21 |12.21 |8.33 |6.32 | + +#### int8量化模型测试数据 + +骁龙855|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |36.80 |21.58 |11.12 | 14.01 |8.13 |4.32 | +mobilenet_v2 |28.72 |19.08 |12.49 | 17.24 |11.55 |7.82 | + +骁龙835|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |60.76 |32.25 |16.66 |56.57 |29.84 |15.24 | +mobilenet_v2 |49.38 |31.10 |22.07 |47.52 |28.18 |19.24 | + + +麒麟970|armv7 | armv7 | armv7 |armv8 | armv8 |armv8 +----| ---- | ---- | ---- | ---- |---- |---- +threads num|1 |2 |4 |1 |2 |4 | +mobilenet_v1 |65.95 |34.39 |18.68 |60.86 |30.98 |16.31 | +mobilenet_v2 |68.87 |39.39 |24.43 |65.57 |37.31 |20.87 | diff --git a/docs/benchmark/benchmark_tools.md b/docs/benchmark/benchmark_tools.md new file mode 100644 index 0000000000000000000000000000000000000000..60341762b70772bc46196b836050714b9d43228b --- /dev/null +++ b/docs/benchmark/benchmark_tools.md @@ -0,0 +1,187 @@ +# Benchmark 测试方法 + +本文将会介绍,在**Ubuntu:16.04交叉编译环境**下,用安卓手机在终端测试Paddle-Lite的性能,并介绍两种Benchmark方法: + +1. **一键Benchmark**:适用于想快速获得常见模型性能的用户,下载预编译好的benchmark可执行文件; +2. **逐步Benchmark**:将**一键Benchmark**流程拆解讲解。 + +## 环境准备 + +1. 准备[adb](https://developer.android.com/studio/command-line/adb)等必备软件: +```shell +sudo apt update +sudo apt install -y wget adb +``` +2. 检查手机与电脑连接。安卓手机USB连上电脑,打开设置 -> 开启开发者模式 -> 开启USB调试 -> 允许(授权)当前电脑调试手机; +3. 在电脑终端输入`adb devices`命令,查看当前连接到的设备: +```shell +adb devices +``` +命令成功执行,显示结果类似下面(序列码略有不同): +```shell +List of devices attached +712QSDSEMMS7C device +``` + +## 一. 一键Benchmark + +执行以下命令,完成Benchmark: + +```shell +wget -c https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/run_benchmark.sh +sh run_benchmark.sh +``` + +该`run_benchmark.sh`脚本会: + +1. 下载模型,并上传手机:包含mobilenetv1/v2、shufflenetv2、squeezenetv1.1、mnasnet; +2. 下载pre-built android-armv7和android-armv8的可执行文件,并上传手机:`benchmark_bin_v7`和`benchmark_bin_v8`; +3. 自动执行另一个脚本`benchmark.sh`(多台手机连接USB,请在`benchmark.sh`脚本中对`adb`命令后加上测试手机的`serial number`); +4. 从手机下载benchmark结果`result_armv7.txt`和`result_armv8.txt`,到当前目录,并显示Benchmark结果。 + +## 二. 逐步Benchmark + +### 1. 
获取benchmark可执行文件 + +benchmark_bin文件可以测试PaddleLite的性能,有下面两种方式获得。 + +#### 方式一:下载benchmark_bin可执行文件 + +```shell +# Download benchmark_bin for android-armv7 +wget -c https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/benchmark_bin_v7 + +# Download benchmark_bin for android-armv8 +wget -c https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/benchmark_bin_v8 +``` + +#### 方式二:由源码编译benchmark_bin文件 + +根据[源码编译](../source_compile)准备编译环境,拉取PaddleLite最新release发布版代码,并在仓库根目录下,执行: + +```shell +########################################### +# Build benchmark_bin for android-armv7 # +########################################### +./lite/tools/ci_build.sh \ + --arm_os="android" \ + --arm_abi="armv7" \ + --arm_lang="gcc " \ + build_arm + +# `benchmark_bin` 在: /build.lite.android.armv7.gcc/lite/api/benchmark_bin + +########################################### +# Build benchmark_bin for android-armv8 # +########################################### +./lite/tools/ci_build.sh \ + --arm_os="android" \ + --arm_abi="armv8" \ + --arm_lang="gcc " \ + build_arm + +# `benchmark_bin` 在: /build.lite.android.armv8.gcc/lite/api/benchmark_bin +``` + +> **注意**:为了避免在docker内部访问不到手机的问题,建议编译得到benchmark_bin后退出到docker外面,并且将benchmark_bin文件拷贝到一个临时目录。然后在该临时目录下,按照下面步骤下载模型、拷贝脚本、测试。 + +### 2. 准备模型 + +PaddleLite为Benchmark准备好了[常见Benchmark模型](https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/benchmark_models.tgz)。 + +执行以下命令,下载常见Benchmark模型并解压: + +```shell +wget -c https://paddle-inference-dist.bj.bcebos.com/PaddleLite/benchmark_0/benchmark_models.tgz +tar zxvf benchmark_models.tgz +``` + +如果测试其他模型,请将模型文件放到 `benchmark_models` 文件夹中。 + +### 3. benchmark.sh脚本 + +benchmark测试的执行脚本`benchmark.sh` 位于源码中的`/PaddleLite/lite/tools/benchmark.sh`位置,测试时需要将`benchmark.sh`、 `benchmark_bin` 、 `benchmark_models` 文件复制到同一目录下。 + +### 4. 
测试 + +从终端进入benchmark.sh、可执行文件(benchmark_bin_v7、benchmark_bin_v8)和模型文件(benchmark_models)所在文件夹。 + +如果 `benchmark_models` 中所有模型文件都已经使用 `model_optimize_tool` 进行转换,则使用 benchmark.sh 脚本执行如下命令进行测试: + +```shell +# Benchmark for android-armv7 +sh benchmark.sh ./benchmark_bin_v7 ./benchmark_models result_armv7.txt + +# Benchmark for android-armv8 +sh benchmark.sh ./benchmark_bin_v8 ./benchmark_models result_armv8.txt +``` + +如果 `benchmark_models` 中所有模型文件都没有使用 `model_optimize_tool` 进行转换,则执行下面的命令。`benchmark_bin` 会首先转换模型,然后加载模型进行测试。 + +```shell +# Benchmark for android-armv7 +sh benchmark.sh ./benchmark_bin_v7 ./benchmark_models result_armv7.txt true + +# Benchmark for android-armv8 +sh benchmark.sh ./benchmark_bin_v8 ./benchmark_models result_armv8.txt true +``` + +测试结束后,armv7和armv8的结果,分别保存在当前目录下的`result_armv7.txt`和`result_armv8.txt`文件中。 + +**查看测试结果** + +在当前目录的`result_armv7.txt`和`result_armv8.txt`文件,查看测试结果。 + +> 不同手机,不同版本,测试模型的性能数据不同。 + +```shell +run benchmark armv7 +-------------------------------------- +PaddleLite Benchmark +Threads=1 Warmup=10 Repeats=30 +-- mnasnet avg = 159.8427 ms +-- mobilenet_v1 avg = 235.0072 ms +-- mobilenet_v2 avg = 173.0387 ms +-- shufflenet_v2 avg = 76.0040 ms +-- squeezenet_v11 avg = 164.2957 ms + +Threads=2 Warmup=10 Repeats=30 +-- mnasnet avg = 83.1287 ms +-- mobilenet_v1 avg = 121.6029 ms +-- mobilenet_v2 avg = 86.6175 ms +-- shufflenet_v2 avg = 41.5761 ms +-- squeezenet_v11 avg = 87.8678 ms + +Threads=4 Warmup=10 Repeats=30 +-- mnasnet avg = 73.3880 ms +-- mobilenet_v1 avg = 119.0739 ms +-- mobilenet_v2 avg = 85.3050 ms +-- shufflenet_v2 avg = 38.0762 ms +-- squeezenet_v11 avg = 64.2201 ms +-------------------------------------- + +run benchmark armv8 +-------------------------------------- +PaddleLite Benchmark +Threads=1 Warmup=10 Repeats=30 +-- mnasnet avg = 165.3073 ms +-- mobilenet_v1 avg = 306.0188 ms +-- mobilenet_v2 avg = 195.1884 ms +-- shufflenet_v2 avg = 99.3692 ms +-- squeezenet_v11 avg = 156.6971 ms + +Threads=2 Warmup=10 Repeats=30 +-- mnasnet avg = 90.2290 ms +-- mobilenet_v1 avg = 157.0007 ms +-- mobilenet_v2 avg = 118.1607 ms +-- shufflenet_v2 avg = 68.6804 ms +-- squeezenet_v11 avg = 91.3090 ms + +Threads=4 Warmup=10 Repeats=30 +-- mnasnet avg = 179.9730 ms +-- mobilenet_v1 avg = 204.0684 ms +-- mobilenet_v2 avg = 181.6486 ms +-- shufflenet_v2 avg = 123.2728 ms +-- squeezenet_v11 avg = 412.9046 ms +-------------------------------------- +``` diff --git a/docs/benchmark/index.rst b/docs/benchmark/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8548e32056a8a824c11f6a622e91c4a6c7da2c --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +# +# Configuration file for the Sphinx documentation builder. +# +# This file does only contain a selection of the most common options. For a +# full list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+# +import os +import sys +#sys.path.insert(0, os.path.abspath('.')) + +import sphinx_rtd_theme +from recommonmark.parser import CommonMarkParser +from recommonmark.transform import AutoStructify + +# -- Project information ----------------------------------------------------- + +project = u'Paddle-Lite' +copyright = u'2020, Paddle-Lite Developer' +author = u'Paddle-Lite Developer' + +# The short X.Y version +version = u'latest' +# The full version, including alpha/beta/rc tags +release = u'' + + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ['recommonmark', 'sphinx_markdown_tables'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = ['.rst', '.md'] + +# The master toctree document. +master_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = None + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = 'Paddle-Litedoc' + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. 
List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'Paddle-Lite.tex', u'Paddle-Lite Documentation', + u'Paddle-Lite Developer', 'manual'), +] + + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'paddle-lite', u'Paddle-Lite Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'Paddle-Lite', u'Paddle-Lite Documentation', + author, 'Paddle-Lite', 'One line description of project.', + 'Miscellaneous'), +] + + +# -- Options for Epub output ------------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = project + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +# +# epub_identifier = '' + +# A unique identification for the text. +# +# epub_uid = '' + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] diff --git a/docs/develop_guides/index.rst b/docs/develop_guides/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/images/architecture.png b/docs/images/architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..35cb336a0640c868d6fc1df738f039a0e7b5884d Binary files /dev/null and b/docs/images/architecture.png differ diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..d7359f1d0508f8e85824f450ca07f095d047f90c --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,71 @@ +.. Paddle-Lite documentation master file, created by + sphinx-quickstart on Thu Feb 6 14:11:30 2020. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to Paddle-Lite's documentation! +======================================= + +.. toctree:: + :maxdepth: 1 + :caption: 简介 + :name: sec-introduction + + introduction/tech_highlights + introduction/architecture + +.. toctree:: + :maxdepth: 1 + :caption: Benchmark数据和方法 + :name: sec-benchmark + + benchmark/benchmark + benchmark/benchmark_tools + +.. toctree:: + :maxdepth: 1 + :caption: 安装 + :name: sec-install + + installation/source_compile + +.. toctree:: + :maxdepth: 1 + :caption: 使用指南 + :name: sec-user-guides + + user_guides/model_optimize_tool + user_guides/library_tailoring + user_guides/cuda + user_guides/opencl + +.. toctree:: + :maxdepth: 1 + :caption: 进阶使用指南 + + advanced_user_guides/support_operation_list + advanced_user_guides/add_operation + advanced_user_guides/add_layout + advanced_user_guides/model_quantization + advanced_user_guides/add_new_pass + advanced_user_guides/x86 + +.. toctree:: + :maxdepth: 1 + :caption: 开发者文档 + +.. toctree:: + :maxdepth: 1 + :caption: API文档 + + api_reference/cxx_api_doc + +.. toctree:: + :maxdepth: 1 + :caption: FAQ + +.. 
toctree:: + :maxdepth: 1 + :caption: paddle-mobile + + diff --git a/docs/installation/library.md b/docs/installation/library.md new file mode 100644 index 0000000000000000000000000000000000000000..ef2f8fdb18ade439d620b348738cbb752d5bd8b6 --- /dev/null +++ b/docs/installation/library.md @@ -0,0 +1,61 @@ + +# 预测库说明 + +Paddle-Lite的编译结果为预测库文件(包括静态库和动态库),具体编译过程参考[源码编译](./source_compile)。 + +Lite预测库分为**基础预测库**和**全量预测库**:基础预测库只打包了基础模型需要的基础算子,预测库体积较小;全量预测库打包了所有的Lite算子,可以支持更多的模型,但是预测库的体积也更大。 编译时由编译选项 `build_extra`(默认为OFF)控制,`--build_extra=OFF`时编译基础预测库,`--build_extra=ON`时编译全量的预测库。 + +## 基础预测库 + +### 编译方法 +编译时设置`--build_extra=OFF` (默认值) 或不指定即可编译出基础预测库。例如: + +``` +./lite/tools/build.sh --arm_os=android --arm_abi=armv8 --arm_lang=gcc --android_stl=c++_static tiny_publish +``` + +### 基础预测库支持的功能 + +(1)支持基础CV模型 + +(2)支持基础的in8量化模型 + +(3)支持[benchmark测试](../benchmark/benchmark) + + +### 基础预测库支持的基础模型: + +1. fluid基础模型(paddle model 提供的基础模型9个) + +``` +mobileNetV1 mnasnet yolov3 ssd_mobilenetv1 shufflenet_v2 +mobileNetV2 resnet50 unet squeezenet_v11 +``` + +2. int8量化模型模型 + +``` +mobilenet_v1 mobilenet_v2 resnet50 +``` + +### 特点 + 轻量级预测库,体积更小,支持常用的基础模型。 + + + +## 全量预测库 + +### 编译方法 +编译时设置`--build_extra=ON` 即可编译出全量预测库。例如: + +``` +./lite/tools/build.sh --arm_os=android --arm_abi=armv8 --arm_lang=gcc --android_stl=c++_static --build_extra=ON tiny_publish +``` +### 全量预测库功能 + +(1) 基础预测库所有功能 + +(2)支持所有Paddle-Lite中注册的所有算子 + +### 特点 + 支持更多的硬件平台和算子,可以支持更多模型但体量更大。 diff --git a/docs/installation/source_compile.md b/docs/installation/source_compile.md new file mode 100644 index 0000000000000000000000000000000000000000..f2016b83188b755eca8daab8a4aa38b25e08c0f1 --- /dev/null +++ b/docs/installation/source_compile.md @@ -0,0 +1,415 @@ + +# 源码编译 + +Paddle-Lite 提供了移动端的一键源码编译脚本 `lite/tools/build.sh`,编译流程如下: + +1. 环境准备(选择其一):Docker交叉编译环境、Linux交叉编译环境 +2. 编译:调用`build.sh`脚本一键编译 + +## 一、环境准备 + +目前支持三种编译的环境: + +1. Docker 容器环境, +2. Linux(推荐 Ubuntu 16.04)环境, +3. Mac OS 环境。 + +### 1、 Docker开发环境 + +[Docker](https://www.docker.com/) 是一个开源的应用容器引擎, 使用沙箱机制创建独立容器,方便运行不同程序。Docker初学者可以参考[Docker使用方法](https://thenewstack.io/docker-station-part-one-essential-docker-concepts-tools-terminology/)正确安装Docker。 + +#### 准备Docker镜像 + +有两种方式准备Docker镜像,推荐从Dockerhub直接拉取Docker镜像 + +```shell +# 方式一:从Dockerhub直接拉取Docker镜像 +docker pull paddlepaddle/paddle-lite:2.0.0_beta + +# 方式二:本地源码编译Docker镜像 +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +cd Paddle-Lite/lite/tools +mkdir mobile_image +cp Dockerfile.mobile mobile_image/Dockerfile +cd mobile_image +docker build -t paddlepaddle/paddle-lite . + +# 镜像编译成功后,可用`docker images`命令,看到`paddlepaddle/paddle-lite`镜像。 +``` + +#### 进入Docker容器 + +在拉取Paddle-Lite仓库代码的上层目录,执行如下代码,进入Docker容器: + +```shell +docker run -it \ + --name paddlelite_docker \ + -v $PWD/Paddle-Lite:/Paddle-Lite \ + --net=host \ + paddlepaddle/paddle-lite /bin/bash +``` + +该命令的含义:将容器命名为`paddlelite_docker`即``,将当前目录下的`Paddle-Lite`文件夹挂载到容器中的`/Paddle-Lite`这个根目录下,并进入容器中。至此,完成Docker环境的准备。 + +#### Docker常用命令 + +```shell +# 退出容器但不停止/关闭容器:键盘同时按住三个键:CTRL + q + p + +# 启动停止的容器 +docker start + +# 从shell进入已启动的容器 +docker attach + +# 停止正在运行的Docker容器 +docker stop + +# 重新启动正在运行的Docker容器 +docker restart + +# 删除Docker容器 +docker rm +``` + +### 2、Linux 开发环境 + +#### Android + +##### 交叉编译环境要求 + +- gcc、g++、git、make、wget、python、adb +- Java environment +- cmake(建议使用3.10或以上版本) +- Android NDK (建议ndk-r17c) + +##### 具体步骤 + +安装软件部分以 Ubuntu 为例,其他 Linux 发行版类似。 + +```shell +# 1. 
Install basic software +apt update +apt-get install -y --no-install-recommends \ + gcc g++ git make wget python unzip adb curl + +# 2. Prepare Java env. +apt-get install -y default-jdk + +# 3. Install cmake 3.10 or above +wget -c https://mms-res.cdn.bcebos.com/cmake-3.10.3-Linux-x86_64.tar.gz && \ + tar xzf cmake-3.10.3-Linux-x86_64.tar.gz && \ + mv cmake-3.10.3-Linux-x86_64 /opt/cmake-3.10 && \ + ln -s /opt/cmake-3.10/bin/cmake /usr/bin/cmake && \ + ln -s /opt/cmake-3.10/bin/ccmake /usr/bin/ccmake + +# 4. Download Android NDK for linux-x86_64 +# Note: Skip this step if NDK installed +# recommand android-ndk-r17c-darwin-x86_64 +# ref: https://developer.android.com/ndk/downloads +cd /tmp && curl -O https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip +cd /opt && unzip /tmp/android-ndk-r17c-linux-x86_64.zip + +# 5. Add environment ${NDK_ROOT} to `~/.bashrc` +echo "export NDK_ROOT=/opt/android-ndk-r17c" >> ~/.bashrc +source ~/.bashrc +``` + +#### ARM Linux + +适用于基于 ARMv8 和 ARMv7 架构 CPU 的各种开发板,例如 RK3399,树莓派等,目前支持交叉编译和本地编译两种方式,对于交叉编译方式,在完成目标程序编译后,可通过 scp 方式将程序拷贝到开发板运行。 + +##### 交叉编译 + +###### 编译环境要求 + +- gcc、g++、git、make、wget、python、scp +- cmake(建议使用3.10或以上版本) + +###### 具体步骤 + +安装软件部分以 Ubuntu 为例,其他 Linux 发行版类似。 + +```shell +# 1. Install basic software +apt update +apt-get install -y --no-install-recommends \ + gcc g++ git make wget python unzip + +# 2. Install arm gcc toolchains +apt-get install -y --no-install-recommends \ + g++-arm-linux-gnueabi gcc-arm-linux-gnueabi \ + g++-arm-linux-gnueabihf gcc-arm-linux-gnueabihf \ + gcc-aarch64-linux-gnu g++-aarch64-linux-gnu + +# 3. Install cmake 3.10 or above +wget -c https://mms-res.cdn.bcebos.com/cmake-3.10.3-Linux-x86_64.tar.gz && \ + tar xzf cmake-3.10.3-Linux-x86_64.tar.gz && \ + mv cmake-3.10.3-Linux-x86_64 /opt/cmake-3.10 && \ + ln -s /opt/cmake-3.10/bin/cmake /usr/bin/cmake && \ + ln -s /opt/cmake-3.10/bin/ccmake /usr/bin/ccmake +``` + +##### 本地编译(直接在RK3399或树莓派上编译) + +###### 编译环境要求 + +- gcc、g++、git、make、wget、python +- cmake(建议使用3.10或以上版本) + +###### 具体步骤 + +安装软件部分以 Ubuntu 为例,其他 Linux 发行版本类似。 + +```shell +# 1. Install basic software +apt update +apt-get install -y --no-install-recomends \ + gcc g++ make wget python unzip + +# 2. install cmake 3.10 or above +wget https://www.cmake.org/files/v3.10/cmake-3.10.3.tar.gz +tar -zxvf cmake-3.10.3.tar.gz +cd cmake-3.10.3 +./configure +make +sudo make install +``` + +之后可通过cmake --version查看cmake是否安装成功。 + +至此,完成 Linux 交叉编译环境的准备。 + +### 3、Mac OS 开发环境 + +#### 交叉编译环境要求 + +- gcc、git、make、curl、unzip、java +- cmake(Android编译请使用3.10版本,IOS编译请使用3.15版本) +- 编译Android: Android NDK (建议ndk-r17c) +- 编译IOS: XCode(Version 10.1) + +#### 具体步骤 + +```bash +# 1. Install basic software +brew install curl gcc git make unzip wget + +# 2. 
Install cmake: mac上实现IOS编译和Android编译要求的cmake版本不一致,可以根据需求选择安装。 +# (1)在mac环境编译 Paddle-Lite 的Android版本,需要安装cmake 3.10 +# mkdir /usr/local/Cellar/cmake/ && cd /usr/local/Cellar/cmake/ +# wget https://cmake.org/files/v3.10/cmake-3.10.2-Darwin-x86_64.tar.gz +# tar zxf ./cmake-3.10.2-Darwin-x86_64.tar.gz +# mv cmake-3.10.2-Darwin-x86_64/CMake.app/Contents/ ./3.10.2 +# ln -s /usr/local/Cellar/cmake/3.10.2/bin/cmake /usr/local/bin/cmake +# (2)在mac环境编译 Paddle-Lite 的IOS版本,需要安装cmake 3.15 +# mkdir /usr/local/Cellar/cmake/ && cd /usr/local/Cellar/cmake/ +# cd /usr/local/Cellar/cmake/ +# wget https://cmake.org/files/v3.15/cmake-3.15.2-Darwin-x86_64.tar.gz +# tar zxf ./cmake-3.15.2-Darwin-x86_64.tar.gz +# mv cmake-3.15.2-Darwin-x86_64/CMake.app/Contents/ ./3.15.2 +# ln -s /usr/local/Cellar/cmake/3.15.2/bin/cmake /usr/local/bin/cmake + +# 3. Download Android NDK for Mac +# recommand android-ndk-r17c-darwin-x86_64 +# ref: https://developer.android.com/ndk/downloads +# Note: Skip this step if NDK installed +cd ~/Documents && curl -O https://dl.google.com/android/repository/android-ndk-r17c-darwin-x86_64.zip +cd ~/Library && unzip ~/Documents/android-ndk-r17c-darwin-x86_64.zip + +# 4. Add environment ${NDK_ROOT} to `~/.bash_profile` +echo "export NDK_ROOT=~/Library/android-ndk-r17c" >> ~/.bash_profile +source ~/.bash_profile + +# 5. Install Java Environment +brew cask install java + +# 6. 编译IOS需要安装XCode(Version 10.1),可以在App Store里安装。安装后需要启动一次并执行下面语句。 +# sudo xcode-select -s /Applications/Xcode.app/Contents/Developer +``` + +至此,完成 Mac 交叉编译环境的准备。 + +**注意**: Mac上编译Paddle-Lite的full_publish版本时,Paddle-Lite所在路径中不可以含有中文字符 + +## 二、编译PaddleLite + +### 下载代码 + +```shell +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +cd Paddle-Lite +git checkout +``` + +### 编译模式与参数 + +编译脚本`./lite/tools/build.sh`,支持三种编译模式: + +| 编译模式 | 介绍 | 适用对象 | +|:-------:|-----|:-------:| +| tiny_publish | 编译移动端部署库,无第三方库依赖 | 用户 | +| full_publish | 编译移动端部署库,有第三方依赖如protobuf、glags等,含有可将模型转换为无需protobuf依赖的naive buffer格式的工具,供tiny_publish库使用 | 用户 | +| test | 编译指定`arm_os`、`arm_abi`下的移动端单元测试 | 框架开发者 | + +编译脚本`./lite/tools/build.sh`,追加参数说明: + +| 参数 | 介绍 | 值 | +|-----------|-------------|-------------| +| --arm_os |必选,选择安装平台 | `android`、`ios`、`ios64`、`armlinux` | +| --arm_abi |必选,选择编译的arm版本,其中`armv7hf`为ARMLinux编译时选用| `armv8`、`armv7`、`armv7hf`(仅`armlinux`支持) | +| --arm_lang |arm_os=android时必选,选择编译器 | `gcc`、`clang`(`clang`当前暂不支持) | +| --android_stl |arm_os=android时必选,选择静态链接STL或动态链接STL | `c++_static`、`c++_shared`| +| --build_java | 可选,是否编译java预测库(默认为OFF) | `ON`、`OFF` | +| --build_extra | 可选,是否编译全量预测库(默认为OFF)。详情可参考[预测库说明](./library.html)。 | `ON`、`OFF` | +| target |必选,选择编译模式,`tiny_publish`为编译移动端部署库、`full_publish`为带依赖的移动端部署库、`test`为移动端单元测试、`ios`为编译ios端`tiny_publish` | `tiny_publish`、`full_publish`、`test`、 `ios` | + +### 编译代码 + +**注意**:非开发者建议在编译前使用[**“加速第三方依赖库的下载”**](#id22)的方法,加速工程中第三方依赖库的下载与编译。 + +#### 编译`tiny publish`动态库 + +##### Android +```shell +./lite/tools/build.sh \ + --arm_os=android \ + --arm_abi=armv8 \ + --build_extra=OFF \ + --arm_lang=gcc \ + --android_stl=c++_static \ + --build_extra=OFF \ + tiny_publish +``` +##### IOS +```shell +./lite/tools/build.sh \ + --arm_os=ios64 \ + --arm_abi=armv8 \ + --build_extra=OFF \ + ios +``` +**注意:mac环境编译IOS 时,cmake版本需要高于cmake 3.15;mac环境上编译Android时,cmake版本需要设置为cmake 3.10。** + +ios tiny publish支持的编译选项: + +* `--arm_os`: 可选ios或者ios64 +* `--arm_abi`: 可选armv7和armv8(**注意**:当`arm_os=ios`时只能选择`arm_abi=armv7`,当`arm_os=ios64`时只能选择`arm_abi=armv8`) +* 如果mac编译过程中报错:"Invalid CMAKE_DEVELOPER_ROOT: does not exist", 运行: 
+```shell +sudo xcode-select -s /Applications/Xcode.app/Contents/Developer +``` +##### ARMLinux +```shell +./lite/tools/build.sh \ + --build_extra=OFF \ + --arm_os=armlinux \ + --arm_abi=armv7hf \ + --arm_lang=gcc \ + --build_extra=OFF \ + tiny_publish +``` +- `--arm_abi`: 树莓派3b使用armv7hf,RK3399使用armv8 + +#### 编译`full publish`动态库 + +##### Android +```shell +./lite/tools/build.sh \ + --arm_os=android \ + --arm_abi=armv8 \ + --build_extra=OFF \ + --arm_lang=gcc \ + --android_stl=c++_static \ + --build_extra=OFF \ + full_publish +``` +##### ARMLinux +```shell +./lite/tools/build.sh \ + --arm_os=armlinux \ + --arm_abi=armv7hf \ + --arm_lang=gcc \ + --build_extra=OFF \ + full_publish +``` +- `--arm_abi`: 树莓派3b使用armv7hf,RK3399使用armv8 + +### 编译结果说明 + +**编译最终产物位置**在 `build.lite.xxx.xxx.xxx` 下的 `inference_lite_lib.xxx.xxx` ,如 Android 下 ARMv8 的产物位于`inference_lite_lib.android.armv8`: + +![](https://user-images.githubusercontent.com/45189361/65375706-204e8780-dccb-11e9-9816-ab4563ce0963.png) + +**目录内容**(可能)如下: + +**Full_publish编译结果:** + +![](https://user-images.githubusercontent.com/45189361/65375704-19c01000-dccb-11e9-9650-6856c7a5bf82.png) + +**Tiny_publish结果:** + +![](https://user-images.githubusercontent.com/45189361/65375726-3bb99280-dccb-11e9-9903-8ce255371905.png) + +**IOS编译结果:** + +![](https://user-images.githubusercontent.com/45189361/65375726-3bb99280-dccb-11e9-9903-8ce255371905.png) + + + +**具体内容**说明: + +1、 `bin`文件夹:可执行工具文件 `paddle_code_generator`、`test_model_bin` + +2、 `cxx`文件夹:包含c++的库文件与相应的头文件 + +- `include` : 头文件 +- `lib` : 库文件 + - 打包的静态库文件: + - `libpaddle_api_full_bundled.a` :包含 full_api 和 light_api 功能的静态库 + - `libpaddle_api_light_bundled.a` :只包含 light_api 功能的静态库 + - 打包的动态态库文件: + - `libpaddle_full_api_shared.so` :包含 full_api 和 light_api 功能的动态库 + - `libpaddle_light_api_shared.so`:只包含 light_api 功能的动态库 + +3、 `demo`文件夹:示例 demo ,包含 C++ demo 和 Java demo。 + +- `cxx` : C++示例 demo + - `mobile_full` : full_api 的使用示例 + - `mobile_light` : light_api的使用示例 +- `java` :Java 示例 demo + - `android` : Java的 Android 示例 + +4、 `java` 文件夹:包含 Jni 的动态库文件与相应的 Jar 包 + +- `jar` : `PaddlePredictor.jar` +- `so` : Jni动态链接库 `libpaddle_lite_jni.so` + +5、 `third_party` 文件夹:第三方库文件`gflags` + +**注意:** + +1、 只有当`--arm_os=android` 时才会编译出: + +- Java库文件与示例:`Java`和`demo/java` + +- 动态库文件:`libpaddle_full_api_shared.so`,`libpaddle_light_api_shared.so` + +2、 `tiny_publish`编译结果不包括 C++ demo和 C++ 静态库,但提供 C++ 的 light_api 动态库、 Jni 动态库和Java demo + +### 加速第三方依赖库的下载 + +移动端相关编译所需的第三方库均位于 `/third-party` 目录下,默认编译过程中,会利用`git submodule update --init --recursive`链上相关的第三方依赖的仓库。 + +为加速`full_publish`、`test`编译模式中对`protobuf`等第三方依赖的下载,`build.sh` 和 `ci_build.sh`支持了从国内 CDN 下载第三方依赖的压缩包。 + +使用方法:`git clone`完`Paddle-Lite`仓库代码后,手动删除本地仓库根目录下的`third-party`目录: + +```shell +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +git checkout +cd Paddle-Lite +rm -rf third-party +``` + +之后再根据本文档,进行后续编译时,便会忽略第三方依赖对应的`submodule`,改为下载第三方压缩包。 diff --git a/docs/introduction/architecture.md b/docs/introduction/architecture.md new file mode 100644 index 0000000000000000000000000000000000000000..1a94494af0b44a03988266d341be5788c46f96c2 --- /dev/null +++ b/docs/introduction/architecture.md @@ -0,0 +1,94 @@ +# 架构设计 + +Mobile 在这次升级为 Lite 架构, 侧重多硬件、高性能的支持,其主要设计思想如下 + +- 引入 Type system,强化多硬件、量化方法、data layout 的混合调度能力 +- 硬件细节隔离,通过不同编译开关,对支持的任何硬件可以自由插拔 +- 引入 MIR(Machine IR) 的概念,强化带执行环境下的优化支持 +- 优化期和执行期严格隔离,保证预测时轻量和高效率 + +架构图如下 + +![Paddle Inference Refactor1.0](https://user-images.githubusercontent.com/52520497/64949619-26e49580-d8ac-11e9-855a-514feb9b75af.png) + +## 
+
+- compile time 优化完毕可以将优化信息存储到模型中;execution time 载入并执行
+- 两套 API 及对应的预测lib,满足不同场景
+  - `CxxPredictor` 打包了 `Compile Time` 和 `Execution Time`,可以 runtime 在具体硬件上做分析和优化,得到最优效果
+  - `MobilePredictor` 只打包 `Execution Time`,保持部署和执行的轻量
+
+## `Execution Time` 轻量级设计和实现
+
+- 每个 batch 实际执行只包含两个步骤
+  - `Op.InferShape`
+  - `Kernel.Run`,Kernel 相关参数均使用指针提前确定,后续无查找或传参消耗
+  - 设计目标:执行时只有 kernel 计算本身的消耗
+- 轻量级 `Op` 及 `Kernel` 设计,避免框架额外消耗
+  - `Op` 只有 `CreateKernels` 和 `InferShape` 两个重要职能
+  - `Kernel` 只有 `Run` 职能
+
+## 多硬件后端支持
+
+- 硬件通用行为,使用 `TargetWrapper` 模块做适配器适配,对上层框架提供一致界面
+- 框架上层策略保持硬件无关,如存储优化 (Memory optimize),计算剪枝 (Computation prune) 等,任何硬件接入均可直接复用
+- 框架支持了硬件通用行为,特定硬件细节不做过多约束,各硬件可以自行实现并接入框架
+- 计算模式上目前支持两种主流模式,一种是类似 X86, ARM CPU 等非异构设备;一种是 GPU,或 FPGA 等异构设备(支持 stream, event异步执行模式以及跨设备拷贝)
+
+---
+## 多硬件及算法混合调度支持
+
+`TensorTy` 用来表示 Tensor 类型:
+
+```c++
+struct TensorTy {
+  TargetType target;
+  PrecisionType precision;
+  DataLayout layout;
+  int deviceid;
+};
+```
+
+```c++
+enum class TargetType { kARM, kX86, kCUDA, kOpenCL };
+enum class PrecisionType { kFP32, kFP16, kInt8, kInt16 };
+enum class DataLayout { kNCHW, kNHWC };
+```
+---
+
+注册 Kernel,确定特定 Kernel 的输入输出特征:
+
+```c++
+REGISTER_LITE_KERNEL(
+    mul, kARM, kFloat, kNCHW, arm::MulCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(kARM, kFloat, kNCHW)})
+    .BindInput("Y", {LiteType::GetTensorTy(kARM, kFloat, kNCHW)})
+    .BindOutput("Out", {LiteType::GetTensorTy(kARM, kFloat, kNCHW)})
+    .Finalize();
+```
+
+---
+
+同一个 Op 的不同 Kernel 类似函数重载。
+
+用于支持任意的混合调度:
+
+1. 标记模型中所有 tensor 的 Type
+2. 标记 Kernel 的硬件、执行精度、data layout 等信息
+
+全局做类型推断,当发现 tensor 传递中有类型冲突时,采用 type cast 操作,通过插入特定功能的 Op 来实现正确的传导。
+
+![lite-7](https://user-images.githubusercontent.com/52520497/64949642-395ecf00-d8ac-11e9-8b69-ced1996abc3b.png)
+
+
+
+---
+
+## MIR 用于图分析优化
+
+基于 Type System 的 SSA,通过 IR Pass 对计算图进行分析和优化:
+
+- 支持对整个 graph 进行类型推断,发现类型冲突并加入 type cast op,来支持通用混合调度
+- 计算剪枝 (Compute prune),比如去掉 scale(1), assign op 等
+- 存储优化 (Memory optimize)
+- 操作熔合 (Operator fuse)(已经支持 fc, conv_bn, ele_add+act 等6种 fuse 策略)
+- 支持量化处理(已支持 Int8 预测)
diff --git a/docs/introduction/tech_highlights.md b/docs/introduction/tech_highlights.md
new file mode 100644
index 0000000000000000000000000000000000000000..83618aaa4bcbd9b7383782d193580e1d3dec7143
--- /dev/null
+++ b/docs/introduction/tech_highlights.md
@@ -0,0 +1,44 @@
+# 技术特点
+
+不同于普通移动端预测引擎所采用的类 Caffe 架构,Lite 架构最早的设计目标来源于 Paddle Server 和 Mobile 两种场景的要求:Server 端需要有完善的图分析和优化能力,Mobile 端要求有轻量级部署的能力;两种场景共同的要求是高性能、多硬件支持等。
+
+基于上述要求,Lite 架构完整实现了相应的能力,重点描述如下。
+
+## 多硬件支持
+
+Lite 架构已经验证和完整支持从 Mobile 到 Server 多种硬件的支持需求,包括 ARM CPU, ARM GPU, Huawei NPU, Intel X86 CPU, NV GPU 等。得益于对不同硬件适度的抽象,Lite 在框架保持清晰的同时支持不同硬件的特殊调度需求,在框架的清晰程度和硬件的特定调度优化上达到很好的平衡,比如 Nvidia GPU 上复杂的 stream, event 分配,在 Lite 中可以清晰表示。
+
+多种硬件的 Kernel 在代码层和执行层均互不干扰,用户可以自由插拔任何硬件的支持。
+
+## 高性能
+
+高性能来源于两方面:一是 Kernel 优化,二是框架执行。
+
+Kernel 方面,我们对相应硬件上的 Kernel 通过指令集、操作熔合、算法改写等方式进行了深入优化。
+
+框架执行方面,通过简化 Op 和 Kernel 的功能,使得执行期的框架开销极低;此外,框架极大的灵活性可以支持各种硬件的特定调度优化,以提升整体效率。
+
+## 量化支持
+
+Lite 支持 PaddleSlim 量化训练完毕的模型,因此完整保留了量化计算的高性能以及量化训练的高精度。
+
+## 强大的图分析和优化能力
+
+在图分析优化上,不同于常规的移动端预测引擎基于 Python 脚本工具转化模型,Lite 架构上有完整的基于 C++ 开发的 IR 及相应 Pass 集合,以支持操作熔合 (Operator fusion)、计算剪枝 (Computation pruning)、存储优化 (Memory optimization)、量化计算 (Quantitative computation) 等多类计算图优化。
+
+更多的优化策略可以简单通过添加 Pass 的方式模块化支持。
+
+## 轻量级部署
+
+尽管图优化上有复杂的策略,但并不影响移动端的轻量级部署,图分析模块和最终的执行引擎可以拆开使用,最终部署只有一层薄薄的 Kernel。
+
+## 可支持任意硬件的混合调度
+
+Lite 支持系统内可见的任意硬件间的混合调度,目前已经支持 ARM CPU 和 ARM GPU 的 Kernel 自动混合调度,并验证了 X86 CPU 和 Nvidia GPU 间的混合调度。
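+
+下面给出一段极简的 C++ 配置示意(非完整实现;模型路径为假设值,`kOpenCL` 等 Place 的可用性取决于编译选项),用来说明混合调度在 API 层的形态:`valid_places` 中各 Place 的先后即 Kernel 选取的优先级,框架全局做类型推断,并在类型冲突处自动插入 type cast Op:
+
+```c++
+#include "paddle_api.h"  // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+int main() {
+  CxxConfig config;
+  config.set_model_dir("./mobilenet_v1");  // 假设的模型目录
+  // 优先选用 ARM GPU(OpenCL)的 Kernel,无对应实现的 Op 回退到 ARM CPU
+  config.set_valid_places({Place{TARGET(kOpenCL), PRECISION(kFloat)},
+                           Place{TARGET(kARM), PRECISION(kFloat)}});
+  auto predictor = CreatePaddlePredictor<CxxConfig>(config);
+  // ... 设置输入后调用 predictor->Run(),混合调度对调用方透明 ...
+  return 0;
+}
+```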
+支持混合调度的考量有两点: + +1. 当系统内同时存在多种硬件可用时,混合调度可以充分利用各类硬件资源 +2. 随着支持模型的增多,各硬件对kernel的支持丰富度不一,难免需要混合调度才能跑通 + +Lite架构通过从底层支持 `Type system` 的方式通用建模各类混合执行的行为,从而能够相对完备地支持混调。 diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..7893348a1b7dbb588983a48e6991282eae7e1b55 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/paddle_mobile/index.rst b/docs/paddle_mobile/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f11fa32f6f465f7b002d7fd37cbd78203206d8d7 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +sphinx +recommonmark +sphinx_markdown_tables +sphinx_rtd_theme diff --git a/docs/user_guides/cuda.md b/docs/user_guides/cuda.md new file mode 100644 index 0000000000000000000000000000000000000000..45597057bb18c44b60234459f9a49a59b54135f6 --- /dev/null +++ b/docs/user_guides/cuda.md @@ -0,0 +1,110 @@ +# Lite基于CUDA的模型预测 + +Lite支持在x86_64,arm64架构上(如:TX2)进行CUDA的编译运行。 + +## 编译 + +**NOTE:** 如果是在TX2等NVIDIA嵌入式硬件上编译,请使用最新的[Jetpack](https://developer.nvidia.com/embedded/jetpack) 安装依赖库。 + + +一: 下载代码 + +``` +git clone https://github.com/PaddlePaddle/Paddle-Lite.git +``` + +二:编译 + +``` +# 进入代码目录 +cd Paddle-Lite + +# 运行编译脚本 +# 编译结束会在本目录下生成 build_cuda 目录 +# 编译过程中如果提示找不到CUDA,CUDNN,请在环境变量设置CUDA_TOOLKIT_ROOT_DIR, CUDNN_ROOT +# CUDA_TOOLKIT_ROOT_DIR,CUDNN_ROOT分别表示CUDA,CUDNN的根目录 +./lite/tools/build.sh cuda +# 如果使用python接口,需要打开build_python选项 +./lite/tools/build.sh --build_python=ON cuda +``` + +编译结束会在 `build_cuda/inference_lite_lib/python/lib/` 目录下生成 `lite_core.so`。 + +## 运行 + +以下以Yolov3模型为例,介绍如何在Nvidia GPU硬件上运行模型。 + +一: 下载darknet_yolov3模型,模型信息请参考[这里](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/yolov3) + + +``` +# 下载模型 +wget https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/yolov3_infer.tar.gz +tar -zxf yolov3_infer.tar.gz +# 下载图片样例 +wget https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/kite.jpg +``` + +二: 运行 + +**NOTE:**此处示例使用的是python接口,后续会开放C++接口以及示例。 + +``` python +#-*- coding: utf-8 -*- +from __future__ import print_function +import sys +import numpy as np +import cv2 +sys.path.append('build_cuda/inference_lite_lib/python/lib') +from lite_core import * + +def read_img(im_path, resize_h, resize_w): + im = cv2.imread(im_path).astype('float32') + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + h, w, _ = im.shape + im_scale_x = resize_h / float(w) + im_scale_y = resize_w / float(h) + out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC) + mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1)) + std = 
np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1))
+    out_img = (out_img / 255.0 - mean) / std
+    out_img = out_img.transpose((2, 0, 1))
+    return out_img
+
+# 配置config
+a = CxxConfig()
+a.set_model_file('./yolov3_infer/__model__')  # 指定模型文件路径
+a.set_param_file('./yolov3_infer/__params__')  # 指定参数文件路径
+place_cuda = Place(TargetType.CUDA)
+a.set_valid_places([place_cuda])
+
+# 创建predictor
+predictor = create_paddle_predictor(a)
+
+# 设置输入
+input_tensor = predictor.get_input(0)
+height, width = 608, 608
+input_tensor.resize([1, 3, height, width])
+data = read_img('./kite.jpg', height, width).flatten()
+input_tensor.set_float_data(data, TargetType.CUDA)
+
+in2 = predictor.get_input(1)
+in2.resize([1, 2])
+in2.set_int32_data([height, width], TargetType.CUDA)
+
+# 运行
+predictor.run()
+
+# 获取输出
+output_tensor = predictor.get_output(0)
+
+print(output_tensor.shape())
+# [100L, 6L]
+print(output_tensor.target())
+# TargetType.Host
+print(output_tensor.float_data()[:6])
+# [0.0, 0.9862784743309021, 98.51927185058594, 471.2381286621094, 120.73092651367188, 578.33251953125]
+
+```
+
+**NOTE:** 对CUDA的支持还在持续开发中。
diff --git a/docs/user_guides/index.rst b/docs/user_guides/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/user_guides/library_tailoring.md b/docs/user_guides/library_tailoring.md
new file mode 100644
index 0000000000000000000000000000000000000000..5ba12cf819945ab2f182f672a2c96123bc12e070
--- /dev/null
+++ b/docs/user_guides/library_tailoring.md
@@ -0,0 +1,185 @@
+
+# 裁剪预测库方法
+
+Paddle-Lite支持**根据模型裁剪预测库**功能。Paddle-Lite的一般编译会将所有已注册的operator打包到预测库中,造成库文件体积膨胀;**裁剪预测库**能针对具体的模型,只打包优化后该模型需要的operator,有效降低预测库文件大小。
+
+## 效果展示(Tiny_publish Android动态预测库体积)
+
+| 测试模型 | 裁剪开关 | **libpaddle_lite_jni.so** |转化后模型中的OP|
+| ------------------ | ---------------------------- | -------- |------------------|
+| mobilenetv1(armv8) | 裁剪前--build_tailor=OFF | 1.5M | feed,fetch,conv2d,depthwise_conv2d,fc,pool2d,softmax |
+| mobilenetv1(armv8) | 裁剪后--build_tailor=ON | 788K |feed,fetch,conv2d,depthwise_conv2d,fc,pool2d,softmax|
+| mobilenetv2(armv8) | 裁剪前--build_tailor=OFF | 1.5M | feed,fetch,conv2d,depthwise_conv2d,elementwise_add,fc,pool2d,relu6,softmax |
+| mobilenetv2(armv8) | 裁剪后--build_tailor=ON | 912K |feed,fetch,conv2d,depthwise_conv2d,elementwise_add,fc,pool2d,relu6,softmax|
+| mobilenetv1(armv7) | 裁剪前--build_tailor=OFF | 938K |feed,fetch,concat,conv2d,dropout,fc,pool2d,softmax|
+| mobilenetv1(armv7) | 裁剪后--build_tailor=ON | 607K |feed,fetch,concat,conv2d,dropout,fc,pool2d,softmax|
+| mobilenetv2(armv7) | 裁剪前--build_tailor=OFF | 938K | feed,fetch,conv2d,depthwise_conv2d,elementwise_add,fc,pool2d,relu6,softmax |
+| mobilenetv2(armv7) | 裁剪后--build_tailor=ON |687K |feed,fetch,conv2d,depthwise_conv2d,elementwise_add,fc,pool2d,relu6,softmax|
+
+
+
+
+## 实现过程
+
+
+### 1、转化模型时记录优化后模型信息
+
+说明:使用model_optimize_tool转化模型时,选择 `--record_tailoring_info=true` 会将优化后模型的OP和kernel信息保存到输出文件夹,这些信息将用于编译裁剪后的动态库。
+注意:需要使用Paddle-Lite 最新版本(release/v2.0.0之后)代码编译出的model_optimize_tool。
+例如:
+
+```bash
+./model_optimize_tool --model_dir=./mobilenet_v1 --optimize_out_type=naive_buffer --optimize_out=mobilenet_v1NB --record_tailoring_info=true --valid_targets=arm
+```
+效果:优化后模型使用的OP和kernel信息被保存在 `mobilenet_v1NB`文件夹下的隐藏文件中。
+
+### 2、根据模型信息编译裁剪后的预测库
+
+说明:编译Paddle-Lite时选择`--build_tailor=ON`,并且用 `--opt_model_dir=` 指定优化后模型的路径。
+例如:
+
+```bash
+./lite/tools/build.sh --arm_os=android --arm_abi=armv7 --arm_lang=gcc --android_stl=c++_static --build_extra=ON \
+  --build_tailor=ON --opt_model_dir=../mobilenet_v1NB full_publish
+```
+**注意**:上面命令中的`../mobilenet_v1NB`是第1步得到的转化模型的输出路径
+
+**效果**:编译出来的动态库文件变小,且可以运行优化后的模型。
+
+编译出的C++预测库文件位于:
+
+`build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/cxx/lib/`
+
+编译出的Java预测库文件位于:
+
+`build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/java/so/`
+
+### 3、运行裁剪后的预测库文件
+
+注意:基于某一模型裁剪出的预测库只能支持优化工具转化后的该模型,例如根据mobilenetV1裁剪出的 full_api预测库只能运行以protobuf格式转化出的模型mobilenetV1_opt_nb, 裁剪出的light_api预测库只能运行以naive_buffer格式转化出的模型mobilenetV1_opt_nb, 运行其他模型可能会出现`segementation fault:undifined op or kernel`。模型转化方法参考:[使用opt转化模型](./model_optimize_tool)。
+
+
+
+**示例1**:使用裁剪后的light_api预测库运行mobilenetv1
+
+1、执行第二步编译后,light_api的C++ 示例位于
+
+`/Paddle-Lite/build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/demo/cxx/mobile_light`
+
+输入`make`命令执行编译,可编译出可执行文件mobilenetv1_light_api
+
+2、使用adb将mobilenetV1_NB模型和mobilenetv1_light_api传到手机后执行demo:
+
+`./mobilenetv1_light_api --model_dir=./mobilenetV1_NB`
+
+注意:`mobilenetV1_NB`是用`mobilenetV1`模型转化的naive_buffer格式模型(不需要设置`--record_tailoring_info=true`,转化流程参考:[使用opt转化模型](./model_optimize_tool))。
+
+
+
+**示例2**:使用裁剪后的full_api预测库运行mobilenetv1
+
+1、执行第二步编译后,full_api的C++ 示例位于
+
+`/Paddle-Lite/build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/demo/cxx/mobile_full`
+
+替换mobilenetv1_full_api.cc代码内容:
+
+```C++
+#include <gflags/gflags.h>
+#include <stdio.h>
+#include <vector>
+#include "paddle_api.h"          // NOLINT
+#include "paddle_use_kernels.h"  // NOLINT
+#include "paddle_use_ops.h"      // NOLINT
+#include "paddle_use_passes.h"   // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+DEFINE_string(model_dir, "", "Model dir path.");
+
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+void RunModel() {
+  // 1. Set CxxConfig
+  CxxConfig config;
+  config.set_model_file(FLAGS_model_dir + "model");
+  config.set_param_file(FLAGS_model_dir + "params");
+
+  std::vector<Place> valid_places{Place{TARGET(kARM), PRECISION(kFloat)}};
+  config.set_valid_places(valid_places);
+
+  // 2. Create PaddlePredictor by CxxConfig
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<CxxConfig>(config);
+
+  // 3. Prepare input data
+  std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+  input_tensor->Resize(shape_t({1, 3, 224, 224}));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+    data[i] = 1;
+  }
+
+  // 4. Run predictor
+  predictor->Run();
+
+  // 5. Get output
+  std::unique_ptr<Tensor> output_tensor(
+      std::move(predictor->GetOutput(0)));
+  printf("Output dim: %d\n", static_cast<int>(output_tensor->shape()[1]));
+  for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
+    printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+  }
+}
+
+int main(int argc, char** argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  RunModel();
+  return 0;
+}
+
+```
+
+2、使用adb将mobilenetV1_PB模型和mobilenetv1_full_api传到手机后执行demo:
+
+`./mobilenetv1_full_api --model_dir=./mobilenetV1_PB`
+
+注意:`mobilenetV1_PB`是用`mobilenetV1`模型转化的protobuf格式模型(不需要设置`--record_tailoring_info=true`,转化流程参考:[使用opt转化模型](./model_optimize_tool))。
+
+## 按模型集合裁剪预测库
+
+为了方便用户使用,我们同时提供了按模型集合进行预测库裁剪的功能。用户可以提供一个模型集合,Model Optimize Tool会根据用户所指定的模型集合,分析其**优化后的**模型所需要的算子信息,并对预测库进行裁剪。借助此功能,用户可以根据自身需要,使用模型集合对预测库中的算子进行任意裁剪。
+
+使用方法如下所示:
+
+```shell
+# 非combined模型集合
+./model_optimize_tool \
+    --model_set_dir=<模型集合路径> \
+    --optimize_out_type=naive_buffer \
+    --optimize_out=<优化模型输出路径> \
+    --record_tailoring_info=true \
+    --valid_targets=arm
+
+# combined模型集合
+./model_optimize_tool \
+    --model_set_dir=<模型集合路径> \
+    --optimize_out_type=naive_buffer \
+    --model_filename=<模型拓扑文件名> \
+    --param_filename=<模型参数文件名> \
+    --optimize_out=<优化模型输出路径> \
+    --record_tailoring_info=true \
+    --valid_targets=arm
+```
+
+经过以上步骤后,会在`<优化模型输出路径>`中生成模型集合中各模型对应的NaiveBuffer格式的优化模型。此步骤会对模型集合中所需算子信息进行搜集,并存储到`<优化模型输出路径>`中。下一步编译预测库的流程与使用单模型进行预测库裁剪的步骤相同。
+
+**注意:**
+
+1. 模型集合**必须**均为combined参数模型或均为非combined参数模型。
+2. 使用非combined参数模型时,模型拓扑文件名应为`__model__`;使用combined参数模型时,集合中各模型的拓扑与参数文件名应相同,分别由`--model_filename`和`--param_filename`指定。
+3. 模型集合**必须**均为INT8量化模型或均为非INT8量化模型。
+4. 需要使用Paddle-Lite 最新版本(release/v2.1.0之后)代码编译出的model_optimize_tool。
diff --git a/docs/user_guides/model_optimize_tool.md b/docs/user_guides/model_optimize_tool.md
new file mode 100644
index 0000000000000000000000000000000000000000..fccc6d8b23c78474257d11399d121816f57fc422
--- /dev/null
+++ b/docs/user_guides/model_optimize_tool.md
@@ -0,0 +1,161 @@
+
+# 模型转化方法
+
+Lite架构在预测过程中表现出来的高性能得益于其丰富的优化组件,其中包括量化、子图融合、混合调度、Kernel优选等策略。为了使优化过程更加方便易用,我们提供了**opt**来自动完成优化步骤,输出一个轻量的、最优的可执行模型。具体使用方法介绍如下:
+
+**注意**:release/v2.2.0之前的模型转化工具名称为`model_optimize_tool`,从release/v2.3开始模型转化工具名称修改为`opt`
+
+## 准备opt
+当前获得opt的方法有三种:
+
+1. 我们提供当前develop分支编译结果下载:[opt](https://paddlelite-data.bj.bcebos.com/model_optimize_tool/opt)、[opt_mac](https://paddlelite-data.bj.bcebos.com/model_optimize_tool/opt_mac)
+release/v2.2.0之前版本的model_optimize_tool: [model_optimize_tool](https://paddlelite-data.bj.bcebos.com/model_optimize_tool/model_optimize_tool)、[model_optimize_tool_mac](https://paddlelite-data.bj.bcebos.com/model_optimize_tool/model_optimize_tool_mac)
+
+2. 可以进入Paddle-Lite Github仓库的[release界面](https://github.com/PaddlePaddle/Paddle-Lite/releases),选择release版本下载对应的转化工具`opt`
+   (release/v2.2.0之前的转化工具为model_optimize_tool、release/v2.3.0之后为opt)
+
+3. 可以下载Paddle-Lite源码,从源码编译出opt工具
+```bash
+git clone https://github.com/PaddlePaddle/Paddle-Lite.git
+cd Paddle-Lite
+git checkout 
+./lite/tools/build.sh build_optimize_tool
+```
+编译结果位于`Paddle-Lite/build.opt/lite/api/opt`。
+**注意**:从源码编译opt前需要先[安装Paddle-Lite的开发环境](../installation/source_compile)。
+
+## 使用opt
+
+opt是x86平台上的可执行文件,需要在PC端运行,包括Linux终端和Mac终端。
+
+### 帮助信息
+执行opt时不加入任何输入选项,会输出帮助信息,提示当前支持的选项:
+```bash
+./opt
+```
+![](https://paddlelite-data.bj.bcebos.com/doc_images/1.png)
+
+### 功能一:转化模型为Paddle-Lite格式
+opt可以将PaddlePaddle支持的模型转化为Paddle-Lite支持的模型格式,期间执行的操作包括:将protobuf格式的模型文件转化为naive_buffer格式的模型文件,有效降低模型体积;执行“量化、子图融合、混合调度、Kernel优选”等图优化操作,提升其在Paddle-Lite上的运行速度、内存占用等性能指标。
+
+模型优化过程:
+
+(1)准备待优化的PaddlePaddle模型
+
+PaddlePaddle模型有两种保存格式:
+   Combined Param:所有参数信息保存在单个文件`params`中,模型的拓扑信息保存在`__model__`文件中。
+
+![opt_combined_model](https://paddlelite-data.bj.bcebos.com/doc_images%2Fcombined_model.png)
+
+   Separated Param:参数信息分开保存在多个参数文件中,模型的拓扑信息保存在`__model__`文件中。
+![opt_seperated_model](https://paddlelite-data.bj.bcebos.com/doc_images%2Fseperated_model.png)
+
+(2) 终端中执行`opt`优化模型
+**使用示例**:转化`mobilenet_v1`模型
+
+```
+./opt --model_dir=./mobilenet_v1 --valid_targets=arm --optimize_out_type=naive_buffer --optimize_out=mobilenet_v1_opt
+```
+以上命令可以将`mobilenet_v1`模型转化为arm硬件平台、naive_buffer格式的Paddle_Lite支持模型,优化后的模型文件为`mobilenet_v1_opt.nb`,转化结果如下图所示:
+
+![opt_resulted_model](https://paddlelite-data.bj.bcebos.com/doc_images/2.png)
+
+
+(3) **更详尽的转化命令**总结:
+
+```shell
+./opt \
+    --model_dir=<model_param_dir> \
+    --model_file=<model_path> \
+    --param_file=<param_path> \
+    --optimize_out_type=(protobuf|naive_buffer) \
+    --optimize_out=<output_optimize_model_dir> \
+    --valid_targets=(arm|opencl|x86|npu|xpu) \
+    --prefer_int8_kernel=(true|false) \
+    --record_tailoring_info=(true|false)
+```
+
+| 选项         | 说明 |
+| ------------------- | ------------------------------------------------------------ |
+| --model_dir | 待优化的PaddlePaddle模型(非combined形式)的路径 |
+| --model_file | 待优化的PaddlePaddle模型(combined形式)的网络结构文件路径。 |
+| --param_file | 待优化的PaddlePaddle模型(combined形式)的权重文件路径。 |
+| --optimize_out_type | 输出模型类型,目前支持两种类型:protobuf和naive_buffer,其中naive_buffer是一种更轻量级的序列化/反序列化实现。若您需要在mobile端执行模型预测,请将此选项设置为naive_buffer。默认为protobuf。 |
+| --optimize_out | 优化模型的输出路径。 |
+| --valid_targets | 指定模型可执行的backend,默认为arm。目前可支持x86、arm、opencl、npu、xpu,可以同时指定多个backend(以空格分隔),Model Optimize Tool将会自动选择最佳方式。如果需要支持华为NPU(Kirin 810/990 Soc搭载的达芬奇架构NPU),应当设置为npu, arm。 |
+| --prefer_int8_kernel | 若待优化模型为int8量化模型(如量化训练得到的量化模型),则设置该选项为true以使用int8内核函数进行推理加速,默认为false。 |
+| --record_tailoring_info | 当使用 [根据模型裁剪库文件](./library_tailoring.html) 功能时,则设置该选项为true,以记录优化后模型含有的kernel和OP信息,默认为false。 |
+
+* 如果待优化的fluid模型是非combined形式,请设置`--model_dir`,忽略`--model_file`和`--param_file`。
+* 如果待优化的fluid模型是combined形式,请设置`--model_file`和`--param_file`,忽略`--model_dir`。
+* 优化后的模型包括__model__.nb和param.nb文件。
+
+### 功能二:统计模型算子信息、判断是否支持
+
+opt可以统计并打印出model中的算子信息,判断Paddle-Lite是否支持该模型,并可以打印出当前Paddle-Lite的算子支持情况。
+
+(1)使用opt统计模型中算子信息
+
+下面命令可以打印出mobilenet_v1模型中包含的所有算子,并判断在硬件平台`valid_targets`下Paddle-Lite是否支持该模型:
+
+`./opt --print_model_ops=true --model_dir=mobilenet_v1 --valid_targets=arm`
+
+![opt_print_modelops](https://paddlelite-data.bj.bcebos.com/doc_images/3.png)
+
+(2)使用opt打印当前Paddle-Lite支持的算子信息
+
+`./opt --print_all_ops=true`
+
+以上命令可以打印出当前Paddle-Lite支持的所有算子信息,包括OP的数量和每个OP支持哪些硬件平台:
+
+![opt_print_allops](https://paddlelite-data.bj.bcebos.com/doc_images/4.png)
+
+`./opt --print_supported_ops=true --valid_targets=x86`
+
+以上命令可以打印出当`valid_targets=x86`时Paddle-Lite支持的所有OP:
+
+![opt_print_supportedops](https://paddlelite-data.bj.bcebos.com/doc_images/5.png)
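+
+转化得到的 `.nb` 模型随后可在端侧用 light_api 加载运行。下面是一段极简的 C++ 示意(非完整实现,模型路径为假设值),其中 `set_model_from_file` 即本文档其他部分提到的单文件模型加载接口:
+
+```c++
+#include "paddle_api.h"  // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+int main() {
+  MobileConfig config;
+  // 加载功能一产出的 naive_buffer 模型(假设的路径)
+  config.set_model_from_file("./mobilenet_v1_opt.nb");
+  auto predictor = CreatePaddlePredictor<MobileConfig>(config);
+  // ... 设置输入后调用 predictor->Run() 即可 ...
+  return 0;
+}
+```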
+
+## 其他功能:合并x2paddle和opt的一键脚本
+
+**背景**:如果想用Paddle-Lite运行第三方来源(tensorflow、caffe、onnx)模型,一般需要经过两次转化:先使用x2paddle工具将第三方模型转化为PaddlePaddle格式,再使用opt将PaddlePaddle模型转化为Paddle-Lite可支持的格式。
+为了简化这一过程,我们提供一键脚本,将x2paddle转化和opt转化合并:
+
+**一键转化脚本**:[auto_transform.sh](https://paddlelite-data.bj.bcebos.com/model_optimize_tool/auto_transform.sh)
+
+
+**环境要求**:使用`auto_transform.sh`脚本转化第三方模型时,需要先安装x2paddle环境,请参考[x2paddle环境安装方法](https://github.com/PaddlePaddle/X2Paddle#环境依赖) 安装x2paddle和其环境依赖项。
+
+**使用方法**:
+
+(1)打印帮助信息:`./auto_transform.sh`
+
+(2)转化模型方法
+
+```bash
+USAGE:
+    auto_transform.sh combines the function of x2paddle and opt, it can
+    transform model from tensorflow/caffe/onnx form into paddle-lite naive-buffer form.
+----------------------------------------
+example:
+    ./auto_transform.sh --framework=tensorflow --model=tf_model.pb --optimize_out=opt_model_result
+----------------------------------------
+Arguments about x2paddle:
+    --framework=(tensorflow|caffe|onnx);
+    --model='model file for tensorflow or onnx';
+    --prototxt='proto file for caffe' --weight='weight file for caffe'
+ For TensorFlow:
+   --framework=tensorflow --model=tf_model.pb
+
+ For Caffe:
+   --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
+
+ For ONNX
+   --framework=onnx --model=onnx_model.onnx
+
+Arguments about opt:
+    --valid_targets=(arm|opencl|x86|npu|xpu); valid targets on Paddle-Lite.
+    --fluid_save_dir='path to output model after x2paddle'
+    --optimize_out='path to output Paddle-Lite model'
+----------------------------------------
+```
diff --git a/docs/user_guides/opencl.md b/docs/user_guides/opencl.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9533af1ff6e2447a8e4d389df90cdb457f58fb2
--- /dev/null
+++ b/docs/user_guides/opencl.md
@@ -0,0 +1,242 @@
+# Lite基于OpenCL的ARM GPU预测
+
+Lite支持在Android系统上运行基于OpenCL的程序,目前支持Ubuntu环境下armv8、armv7的交叉编译。
+
+## 编译
+
+### 编译环境
+
+1. Docker 容器环境;
+2. Linux(推荐 Ubuntu 16.04)环境。
+
+详见 **源码编译指南-环境准备** 章节。
+
+### 编译选项
+
+|参数|介绍|值|
+|--------|--------|--------|
+|--arm_os|代表目标操作系统|目前仅支持且默认为`android`|
+|--arm_abi|代表体系结构类型,支持armv8和armv7|默认为`armv8`即arm64-v8a;`armv7`即armeabi-v7a|
+|--arm_lang|代表编译目标文件所使用的编译器|默认为gcc,支持 gcc和clang两种|
+
+### 编译Paddle-Lite OpenCL库范例
+
+注:以android-armv8-opencl为目标、Docker容器为编译开发环境,CMake 3.10,android-ndk-r17c位于`/opt/`目录下。
+
+```bash
+# 假设当前处于Lite源码根目录下
+
+# 导入NDK_ROOT变量,注意检查您的安装目录若与本示例不同
+export NDK_ROOT=/opt/android-ndk-r17c
+
+# 删除上一次CMake自动生成的.h文件
+rm ./lite/api/paddle_use_kernels.h
+rm ./lite/api/paddle_use_ops.h
+
+# 根据指定编译参数编译
+./lite/tools/ci_build.sh \
+  --arm_os=android \
+  --arm_abi=armv8 \
+  --arm_lang=gcc \
+  build_test_arm_opencl
+```
+
+编译产物位于`build.lite.android.armv8.gcc.opencl`下的`inference_lite_lib.android.armv8.opencl`文件夹内,这里仅罗列关键产物:
+
+- `cxx`:该目录是编译目标的C++的头文件和库文件;
+- `demo`:该目录包含了两个demo,用来调用`libpaddle_api_full_bundled.a`和`libpaddle_api_light_bundled.a`,分别对应`mobile_full`和`mobile_light`文件夹。编译对应的demo仅需在`mobile_full`或`mobile_light`文件夹下执行`make`:
+  - `mobile_full`:使用cxx config,可直接加载fluid模型,若使用OpenCL需要在`mobilenetv1_full_api.cc`代码里开启`DEMO_USE_OPENCL`的宏,详细见代码注释;
+  - `mobile_light`:使用mobile config,只能加载`model_optimize_tool`优化过的模型;
+- `opencl`:该目录存放opencl实现的相关kernel。
+
+```bash
+.
+|-- cxx +| |-- include +| | |-- paddle_api.h +| | |-- paddle_image_preprocess.h +| | |-- paddle_lite_factory_helper.h +| | |-- paddle_place.h +| | |-- paddle_use_kernels.h +| | |-- paddle_use_ops.h +| | `-- paddle_use_passes.h +| `-- lib +| |-- libpaddle_api_full_bundled.a +| |-- libpaddle_api_light_bundled.a +| |-- libpaddle_full_api_shared.so +| `-- libpaddle_light_api_shared.so +|-- demo +| `-- cxx +| |-- Makefile.def +| |-- README.md +| |-- include +| | |-- paddle_api.h +| | |-- paddle_lite_factory_helper.h +| | |-- paddle_place.h +| | |-- paddle_use_kernels.h +| | |-- paddle_use_ops.h +| | `-- paddle_use_passes.h +| |-- mobile_full +| | |-- Makefile +| | `-- mobilenetv1_full_api.cc +| `-- mobile_light +| |-- Makefile +| `-- mobilenetv1_light_api.cc +`-- opencl + `-- cl_kernel + |-- buffer + | |-- depthwise_conv2d_kernel.cl + | |-- elementwise_add_kernel.cl + | |-- fc_kernel.cl + | |-- im2col_kernel.cl + | |-- layout_kernel.cl + | |-- mat_mul_kernel.cl + | |-- pool_kernel.cl + | `-- relu_kernel.cl + |-- cl_common.h + `-- image + |-- channel_add_kernel.cl + |-- elementwise_add_kernel.cl + |-- pool_kernel.cl + `-- relu_kernel.cl +``` + +调用`libpaddle_api_full_bundled.a`和`libpaddle_api_light_bundled.a`见下一部分运行示例。 + + + +## 运行示例 + +下面以android、ARMv8、gcc的环境为例,介绍3个示例,分别如何在手机上执行基于OpenCL的ARM GPU推理过程。 + + +**注意:** 以下命令均在Lite源码根目录下运行。在3个示例前,下面这段命令都先要执行用来准备环境: + +```bash +# 在/data/local/tmp目录下创建OpenCL文件目录 +adb shell mkdir -p /data/local/tmp/opencl +adb shell mkdir -p /data/local/tmp/opencl/cl_kernel/buffer +adb shell mkdir -p /data/local/tmp/opencl/cl_kernel/image + +# 将OpenCL的kernels文件推送到/data/local/tmp/opencl目录下 +adb push lite/backends/opencl/cl_kernel/cl_common.h /data/local/tmp/opencl/cl_kernel/ +adb push lite/backends/opencl/cl_kernel/buffer/* /data/local/tmp/opencl/cl_kernel/buffer/ +adb push lite/backends/opencl/cl_kernel/image/* /data/local/tmp/opencl/cl_kernel/image/ +``` + +### 运行示例1: 编译产物demo示例 + +```bash +###################################################################### +# 编译mobile_full的demo # +###################################################################### +# 步骤: # +# 0.确保编译Paddle-Lite时编译了OpenCL; # +# 1.编辑`mobilenetv1_full_api.cc`代码, 开启`DEMO_USE_OPENCL`的宏; # +# 2.在产物目录`demo/cxx/mobile_full`下编译`mobile_full`的demo; # +# 3.上传demo, 模型, opencl kernel文件到手机; # +# 4.运行demo得到预期结果. # +###################################################################### +adb shell mkdir /data/local/tmp/opencl/mobilenet_v1 +chmod +x ./build.lite.android.armv8.gcc.opencl/inference_lite_lib.android.armv8.opencl/demo/cxx/mobile_full/mobilenetv1_full_api +adb push ./build.lite.android.armv8.gcc.opencl/inference_lite_lib.android.armv8.opencl/demo/cxx/mobile_full/mobilenetv1_full_api /data/local/tmp/opencl/ +adb push ./build.lite.android.armv8.gcc.opencl/install/mobilenet_v1/* /data/local/tmp/opencl/mobilenet_v1 + +# use mobile_full run mobilenet_v1 +# `GLOG_v` is log level +adb shell "export GLOG_v=0; \ + /data/local/tmp/opencl/mobilenetv1_full_api \ + --model_dir=/data/local/tmp/opencl/mobilenet_v1 \ + --optimized_model_dir=/data/local/tmp/opencl/full_api_opt_model" + + + +###################################################################### +# 编译mobile_light的demo # +###################################################################### +# 步骤: # +# 0.确保编译Paddle-Lite时编译了OpenCL; # +# 1.编译model_optimize_tool并对模型优化, `targets`参数为`opencl`; # +# 2.在产物目录`demo/cxx/mobile_light`下编译`mobile_light`的demo; # +# 3.上传demo, 模型, opencl kernel文件到手机; # +# 4.运行demo得到预期结果. 
#
+######################################################################
+
+# use model_optimize_tool to optimize model
+./build.model_optimize_tool/lite/api/model_optimize_tool \
+  --model_dir=./build.lite.android.armv8.gcc.opencl/install/mobilenet_v1/ \
+  --optimize_out_type=naive_buffer \
+  --optimize_out=./build.lite.android.armv8.gcc.opencl/install/mobilenet_v1/ \
+  --valid_targets=opencl
+
+adb shell mkdir /data/local/tmp/opencl/mobilenet_v1
+chmod +x ./build.lite.android.armv8.gcc.opencl/inference_lite_lib.android.armv8.opencl/demo/cxx/mobile_light/mobilenetv1_light_api
+adb push ./build.lite.android.armv8.gcc.opencl/inference_lite_lib.android.armv8.opencl/demo/cxx/mobile_light/mobilenetv1_light_api /data/local/tmp/opencl/
+adb push ./build.lite.android.armv8.gcc.opencl/install/mobilenet_v1/* /data/local/tmp/opencl/mobilenet_v1
+
+# use mobile_light run mobilenet_v1
+adb shell "export GLOG_v=5; \
+   /data/local/tmp/opencl/mobilenetv1_light_api \
+   --model_dir=/data/local/tmp/opencl/mobilenet_v1"
+```
+
+### 运行示例2: test_mobilenetv1单元测试
+
+- **运行文件准备**
+
+```bash
+# 将mobilenet_v1的模型文件推送到/data/local/tmp/opencl目录下
+adb shell mkdir -p /data/local/tmp/opencl/mobilenet_v1
+adb push build.lite.android.armv8.gcc.opencl/third_party/install/mobilenet_v1/* /data/local/tmp/opencl/mobilenet_v1/
+
+# 将OpenCL单元测试程序test_mobilenetv1,推送到/data/local/tmp/opencl目录下
+adb push build.lite.android.armv8.gcc.opencl/lite/api/test_mobilenetv1 /data/local/tmp/opencl
+```
+
+- **执行OpenCL推理过程**
+
+使用如下命令运行OpenCL程序。其中:
+
+- `--cl_path`指定了OpenCL的kernels文件即cl\_kernel所在目录;
+- `--model_dir`指定了模型文件所在目录。
+
+```bash
+adb shell chmod +x /data/local/tmp/opencl/test_mobilenetv1
+
+adb shell /data/local/tmp/opencl/test_mobilenetv1 \
+  --cl_path=/data/local/tmp/opencl \
+  --model_dir=/data/local/tmp/opencl/mobilenet_v1 \
+  --warmup=1 \
+  --repeats=1
+```
+
+**注意:** 因为权重参数均会在Op Kernel第一次运行时进行加载,所以第一次的执行时间会略长。一般将warmup的值设为1,repeats值设为多次。
+
+### 运行示例3: test_layout_opencl单元测试
+
+- **运行文件准备**
+
+```bash
+# 将OpenCL单元测试程序test_layout_opencl,推送到/data/local/tmp/opencl目录下
+adb push build.lite.android.armv8.gcc.opencl/lite/kernels/opencl/test_layout_opencl /data/local/tmp/opencl/
+```
+
+- **执行OpenCL推理过程**
+
+```bash
+adb shell chmod +x /data/local/tmp/opencl/test_layout_opencl
+adb shell /data/local/tmp/opencl/test_layout_opencl
+```
+
+
+## 如何在Code中使用
+
+见运行示例1的demo代码:
+
+1. [./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc);
+2. [./lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc).
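+
+在查阅上述完整示例之前,可以先看下面这段基于 light_api 的极简调用示意(非完整实现;模型路径为假设值,假定模型已用 `--valid_targets=opencl` 优化为 naive_buffer 格式):
+
+```c++
+#include <stdio.h>
+#include "paddle_api.h"  // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+int main() {
+  // 加载经 model_optimize_tool 优化后的模型(假设的路径)
+  MobileConfig config;
+  config.set_model_dir("/data/local/tmp/opencl/mobilenet_v1");
+  auto predictor = CreatePaddlePredictor<MobileConfig>(config);
+
+  // 构造 1x3x224x224 的全 1 输入,仅作演示
+  auto input = predictor->GetInput(0);
+  input->Resize({1, 3, 224, 224});
+  auto* data = input->mutable_data<float>();
+  for (int i = 0; i < 1 * 3 * 224 * 224; ++i) data[i] = 1.f;
+
+  predictor->Run();
+
+  // 读取输出;OpenCL 与 ARM CPU Kernel 的选取由框架完成,对调用方透明
+  auto output = predictor->GetOutput(0);
+  printf("Output[0]: %f\n", output->data<float>()[0]);
+  return 0;
+}
+```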
+ +注:这里给出的链接会跳转到线上最新develop分支的代码,很可能与您本地的代码存在差异,建议参考自己本地位于`lite/demo/cxx/`目录的代码,查看如何使用。 + +**NOTE:** 对OpenCL的支持还在持续开发中。 diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index 61f07583b2ed920ce7ac0f2d56b2b2e89bb99b42..bac6f80c4721e0c5de201eebfe7e6a39a0bdc73a 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -5,9 +5,11 @@ message(STATUS "LIGHT_FRAMEWORK:\t${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}") message(STATUS "LITE_WITH_CUDA:\t${LITE_WITH_CUDA}") message(STATUS "LITE_WITH_X86:\t${LITE_WITH_X86}") message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") +message(STATUS "LITE_WITH_OPENCL:\t${LITE_WITH_OPENCL}") message(STATUS "LITE_WITH_NPU:\t${LITE_WITH_NPU}") message(STATUS "LITE_WITH_XPU:\t${LITE_WITH_XPU}") message(STATUS "LITE_WITH_FPGA:\t${LITE_WITH_FPGA}") +message(STATUS "LITE_WITH_BM:\t${LITE_WITH_BM}") message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") message(STATUS "LITE_WITH_CV:\t${LITE_WITH_CV}") @@ -65,6 +67,9 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) if (LITE_WITH_FPGA) set(INFER_LITE_PUBLISH_ROOT "${INFER_LITE_PUBLISH_ROOT}.fpga") endif(LITE_WITH_FPGA) + if (LITE_WITH_BM) + set(INFER_LITE_PUBLISH_ROOT "${INFER_LITE_PUBLISH_ROOT}.bm") + endif(LITE_WITH_BM) else() set(INFER_LITE_PUBLISH_ROOT "${CMAKE_BINARY_DIR}/inference_lite_lib") endif() @@ -160,7 +165,7 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/include" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/include" COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/lib" - COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/include" ) add_dependencies(tiny_publish_lib bundle_light_api) add_dependencies(publish_inference tiny_publish_lib) @@ -171,11 +176,17 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/include" COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" + COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/libpaddle_light_api_shared.so" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include" ) add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared) + add_dependencies(tiny_publish_cxx_lib bundle_light_api) add_dependencies(publish_inference tiny_publish_cxx_lib) + if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD + COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so) + endif() endif() endif() endif() @@ -213,7 +224,16 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_full/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_full/Makefile" COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_light" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_light/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_light/Makefile" - COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" 
"${INFER_LITE_PUBLISH_ROOT}/demo/cxx/include" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/ssd_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/ssd_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/ssd_detection/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/yolov3_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/yolov3_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/yolov3_detection/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_classify" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_classify/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_classify/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_cv" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_cv/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_cv/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mask_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mask_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mask_detection/Makefile" ) add_dependencies(publish_inference_android_cxx_demos logging gflags) add_dependencies(publish_inference_cxx_lib publish_inference_android_cxx_demos) @@ -225,6 +245,16 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/README.md" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_light" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_light/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_light/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/ssd_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/ssd_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/ssd_detection/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/yolov3_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/yolov3_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/yolov3_detection/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_classify" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_classify/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_classify/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_cv" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_cv/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_cv/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mask_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mask_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mask_detection/Makefile" ) 
add_dependencies(tiny_publish_cxx_lib publish_inference_android_cxx_demos) endif() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 63d53869ea530212ea03b24ef746d980fd13a19b..f7f74ab5822a1305e3e8d24cf36a0a458a6494ff 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -16,32 +16,40 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGE add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto) target_link_libraries(paddle_full_api_shared framework_proto) if(LITE_WITH_X86) - add_dependencies(paddle_full_api_shared xxhash) - target_link_libraries(paddle_full_api_shared xxhash) + add_dependencies(paddle_full_api_shared xxhash) + target_link_libraries(paddle_full_api_shared xxhash) + if (NOT LITE_ON_MODEL_OPTIMIZE_TOOL) + add_dependencies(paddle_full_api_shared dynload_mklml) + endif() endif() if(LITE_WITH_CUDA) target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive") - endif(LITE_WITH_CUDA) + endif(LITE_WITH_CUDA) + #light api dynamic library lite_cc_library(paddle_light_api_shared MODULE - SRCS light_api_shared.cc - DEPS ${light_lib_DEPS} - ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels}) + SRCS light_api_shared.cc + DEPS ${light_lib_DEPS} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels}) + target_link_libraries(paddle_light_api_shared ${light_lib_DEPS} ${arm_kernels} ${npu_kernels}) - if (LITE_WITH_NPU) - # Strips the symbols of our protobuf functions to fix the conflicts during - # loading HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so) - set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") - set_target_properties(paddle_light_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - endif() + set(LINK_MAP_FILE "${PADDLE_SOURCE_DIR}/lite/core/lite.map") + set(LINK_FLAGS "-Wl,--version-script ${LINK_MAP_FILE}") + add_custom_command(OUTPUT ${LINK_MAP_FILE} COMMAND ...) 
+ add_custom_target(custom_linker_map DEPENDS ${LINK_MAP_FILE}) + set_target_properties(paddle_full_api_shared PROPERTIES LINK_FLAGS ${LINK_FLAGS}) + add_dependencies(paddle_full_api_shared custom_linker_map) else() if ((ARM_TARGET_OS STREQUAL "android") OR (ARM_TARGET_OS STREQUAL "armlinux")) add_library(paddle_light_api_shared SHARED "") target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc) + set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "-flto -fdata-sections") add_dependencies(paddle_light_api_shared op_list_h kernel_list_h) if (LITE_WITH_NPU) # Need to add HIAI runtime libs (libhiai.so) dependency - target_link_libraries(paddle_light_api_shared ${npu_runtime_libs}) + target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs}) endif() endif() endif() @@ -52,13 +60,19 @@ if (WITH_TESTING) ${ops} ${host_kernels} CUDA_DEPS ${cuda_kernels} X86_DEPS ${x86_kernels} - XPU_DEPS ${xpu_kernels}) + XPU_DEPS ${xpu_kernels} + BM_DEPS ${bm_kernels}) endif() if(LITE_WITH_FPGA) set(light_api_deps ${light_api_deps} ${fpga_deps}) set(cxx_api_deps ${cxx_api_deps} ${fpga_deps}) endif() +if(LITE_WITH_BM) + set(light_api_deps ${light_api_deps} ${bm_deps}) + set(cxx_api_deps ${cxx_api_deps} ${bm_deps}) +endif() + message(STATUS "get ops ${ops}") message(STATUS "get X86 kernels ${x86_kernels}") message(STATUS "get CUDA kernels ${cuda_kernels}") @@ -67,29 +81,32 @@ message(STATUS "get ARM kernels ${arm_kernels}") message(STATUS "get NPU kernels ${npu_kernels}") message(STATUS "get XPU kernels ${xpu_kernels}") message(STATUS "get FPGA kernels ${fpga_kernels}") +message(STATUS "get BM kernels ${bm_kernels}") # for full api if (NOT LITE_ON_TINY_PUBLISH) set(cxx_api_deps - scope optimizer target_wrapper_host model_parser program) + scope optimizer target_wrapper_host model_parser program) lite_cc_library(cxx_api - SRCS cxx_api.cc - DEPS ${cxx_api_deps} ${ops} ${host_kernels} program - X86_DEPS ${x86_kernels} - ARM_DEPS ${arm_kernels} - NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass - XPU_DEPS ${xpu_kernels} ${xpu_bridges} xpu_pass - CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels}) + SRCS cxx_api.cc + DEPS ${cxx_api_deps} ${ops} ${host_kernels} program + X86_DEPS ${x86_kernels} + CUDA_DEPS ${cuda_kernels} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} + BM_DEPS ${bm_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels}) endif() # for light api set(light_api_deps scope target_wrapper_host model_parser program) if(LITE_WITH_CUDA) + get_property(cuda_deps GLOBAL PROPERTY CUDA_MODULES) set(light_api_deps ${light_api_deps} target_wrapper_cuda) - set(cuda_static_deps cudart_static cublas_static curand_static - cudnn_static culibos_static) endif() lite_cc_library(light_api SRCS light_api.cc DEPS scope target_wrapper_host model_parser @@ -97,10 +114,12 @@ lite_cc_library(light_api SRCS light_api.cc CUDA_DEPS ${cuda_kernels} X86_DEPS ${x86_kernels} ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm NPU_DEPS ${npu_kernels} XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels}) + FPGA_DEPS ${fpga_kernels} + BM_DEPS ${bm_kernels}) include(ExternalProject) set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING @@ -111,11 +130,14 @@ if(WITH_TESTING) DEPS cxx_api mir_passes lite_api_test_helper ${ops} ${host_kernels} X86_DEPS ${x86_kernels} + CUDA_DEPS ${cuda_kernels} ARM_DEPS ${arm_kernels} + CV_DEPS 
paddle_cv_arm NPU_DEPS ${npu_kernels} XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} + BM_DEPS ${bm_kernels} EXCLUDE_COMPILE_DEPS "ON" ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) @@ -151,6 +173,12 @@ if(WITH_TESTING) ${ops} ${host_kernels} ${x86_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/step_rnn) add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz) + if(LITE_WITH_BM) + lite_cc_test(test_resnet50_lite_bm SRCS test_resnet50_lite_bm.cc + DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} ${bm_kernels} ${bm_bridges} + ARGS --model_dir=${LITE_MODEL_DIR}/resnet50) + endif() endif() endif() @@ -221,16 +249,18 @@ else() endif() if (NOT LITE_ON_TINY_PUBLISH) lite_cc_library(paddle_api_full SRCS cxx_api_impl.cc DEPS cxx_api paddle_api_light - ${ops} - ARM_DEPS ${arm_kernels} - NPU_DEPS ${npu_kernels} - CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels}) + ${ops} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels}) # The final inference library for just MobileConfig. bundle_static_library(paddle_api_full paddle_api_full_bundled bundle_full_api) + target_link_libraries(paddle_api_full ${cuda_deps}) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) - cc_library(api_full_static SRCS DEPS paddle_api_full cxx_api paddle_api light_api ${cxx_api_deps} ${ops} ${host_kernels} ${cuda_kernels} program tensor memory naive_buffer types ${fluid_modules} protobuf ${cuda_static_deps}) endif() + bundle_static_library(paddle_api_light paddle_api_light_bundled bundle_light_api) #----------------------------------------------------------------------------------------------------- @@ -240,6 +270,7 @@ lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api program mir_passes paddle_api_light CL_DEPS ${opencl_kernels} FPGA_DEPS ${fpga_kernels} + BM_DEPS ${bm_kernels} ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) lite_cc_test(test_apis SRCS apis_test.cc @@ -248,6 +279,7 @@ lite_cc_test(test_apis SRCS apis_test.cc X86_DEPS ${x86_kernels} XPU_DEPS ${xpu_kernels} FPGA_DEPS ${fpga_kernels} + BM_DEPS ${bm_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) @@ -255,7 +287,7 @@ if (LITE_WITH_JAVA AND LITE_WITH_ARM) add_subdirectory(android) endif() -if (LITE_WITH_PYTHON) +if (LITE_WITH_PYTHON) add_subdirectory(python) endif() @@ -264,20 +296,22 @@ if (LITE_ON_TINY_PUBLISH) endif() if (LITE_ON_MODEL_OPTIMIZE_TOOL) - message(STATUS "Compiling model_optimize_tool") - lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc + message(STATUS "Compiling opt") + lite_cc_binary(opt SRCS opt.cc cxx_api_impl.cc paddle_api.cc cxx_api.cc DEPS gflags kernel op optimizer mir_passes utils) - add_dependencies(model_optimize_tool op_list_h kernel_list_h all_kernel_faked_cc) + add_dependencies(opt op_list_h kernel_list_h all_kernel_faked_cc supported_kernel_op_info_h) endif(LITE_ON_MODEL_OPTIMIZE_TOOL) lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light ${ops} ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm NPU_DEPS ${npu_kernels} XPU_DEPS ${xpu_kernels} CL_DEPS ${opencl_kernels} X86_DEPS ${x86_kernels} FPGA_DEPS ${fpga_kernels} + BM_DEPS ${bm_kernels} ARGS 
--model_dir=${LITE_MODEL_DIR}/lite_naive_model SERIAL) if (WITH_TESTING) add_dependencies(test_paddle_api extern_lite_download_lite_naive_model_tar_gz) @@ -285,25 +319,39 @@ endif() # Some bins if(NOT IOS) - lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags utils - ${ops} ${host_kernels} - ARM_DEPS ${arm_kernels} - NPU_DEPS ${npu_kernels} - XPU_DEPS ${xpu_kernels} - CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels} - X86_DEPS ${x86_kernels} - CUDA_DEPS ${cuda_kernels}) - lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils - ${ops} ${host_kernels} - ARM_DEPS ${arm_kernels} - NPU_DEPS ${npu_kernels} - XPU_DEPS ${xpu_kernels} - CL_DEPS ${opencl_kernels} - FPGA_DEPS ${fpga_kernels} - X86_DEPS ${x86_kernels} - CUDA_DEPS ${cuda_kernels}) + lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} + CL_DEPS ${opencl_kernels} + BM_DEPS ${bm_kernels} + FPGA_DEPS ${fpga_kernels} + X86_DEPS ${x86_kernels} + CUDA_DEPS ${cuda_kernels}) + lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels} + X86_DEPS ${x86_kernels} + CUDA_DEPS ${cuda_kernels}) + lite_cc_binary(multithread_test SRCS lite_multithread_test.cc DEPS paddle_api_full paddle_api_light gflags utils + ${ops} ${host_kernels} + ARM_DEPS ${arm_kernels} + CV_DEPS paddle_cv_arm + NPU_DEPS ${npu_kernels} + XPU_DEPS ${xpu_kernels} + CL_DEPS ${opencl_kernels} + BM_DEPS ${bm_kernels} + FPGA_DEPS ${fpga_kernels} + X86_DEPS ${x86_kernels} + CUDA_DEPS ${cuda_kernels}) endif() #lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc diff --git a/lite/api/_paddle_use_ops.h b/lite/api/_paddle_use_ops.h index bdccfab5df67e485b9fef110dc6cc1e9d74b21c3..6da47e53789d651f4a36d0b8d6a7ca1ea5a0a3d3 100644 --- a/lite/api/_paddle_use_ops.h +++ b/lite/api/_paddle_use_ops.h @@ -108,7 +108,7 @@ USE_LITE_OP(while) USE_LITE_OP(lod_reset) USE_LITE_OP(lookup_table) USE_LITE_OP(multiclass_nms) -USE_LITE_OP(graph_op) +USE_LITE_OP(subgraph) USE_LITE_OP(sequence_expand) USE_LITE_OP(sequence_pool) USE_LITE_OP(reduce_max) diff --git a/lite/api/android/jni/native/CMakeLists.txt b/lite/api/android/jni/native/CMakeLists.txt index 3efa980332f25d786d5c880fab9b3ba5af0a1013..c1766772f8aaa417c3da1d72f2692c10c10194b4 100644 --- a/lite/api/android/jni/native/CMakeLists.txt +++ b/lite/api/android/jni/native/CMakeLists.txt @@ -25,11 +25,12 @@ if (NOT LITE_ON_TINY_PUBLISH) endif() else() add_library(paddle_lite_jni SHARED "") + set_target_properties(paddle_lite_jni PROPERTIES COMPILE_FLAGS "-flto -fdata-sections") target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc) add_dependencies(paddle_lite_jni op_list_h kernel_list_h) if (LITE_WITH_NPU) # Need to add HIAI runtime libs (libhiai.so) dependency - target_link_libraries(paddle_lite_jni ${npu_runtime_libs}) + target_link_libraries(paddle_lite_jni ${npu_builder_libs} ${npu_runtime_libs}) endif() endif() diff --git a/lite/api/android/jni/native/convert_util_jni.h b/lite/api/android/jni/native/convert_util_jni.h index 5e5d3723e43eb311f64b85f7507a12497d724109..e4adafdc572fdc937f568508aa9d43eb78470d0d 100644 --- 
a/lite/api/android/jni/native/convert_util_jni.h +++ b/lite/api/android/jni/native/convert_util_jni.h @@ -181,6 +181,7 @@ inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env, MobileConfig config; // set model dir + // NOTE: This is a deprecated API and will be removed in latter release. jmethodID model_dir_method = env->GetMethodID( mobileconfig_jclazz, "getModelDir", "()Ljava/lang/String;"); jstring java_model_dir = @@ -190,6 +191,27 @@ inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env, config.set_model_dir(cpp_model_dir); } + // set model from file + jmethodID model_file_method = env->GetMethodID( + mobileconfig_jclazz, "getModelFromFile", "()Ljava/lang/String;"); + jstring java_model_file = + (jstring)env->CallObjectMethod(jmobileconfig, model_file_method); + if (java_model_file != nullptr) { + std::string cpp_model_file = jstring_to_cpp_string(env, java_model_file); + config.set_model_from_file(cpp_model_file); + } + + // set model from buffer + jmethodID model_buffer_method = env->GetMethodID( + mobileconfig_jclazz, "getModelFromBuffer", "()Ljava/lang/String;"); + jstring java_model_buffer = + (jstring)env->CallObjectMethod(jmobileconfig, model_buffer_method); + if (java_model_buffer != nullptr) { + std::string cpp_model_buffer = + jstring_to_cpp_string(env, java_model_buffer); + config.set_model_from_buffer(cpp_model_buffer); + } + // set threads jmethodID threads_method = env->GetMethodID(mobileconfig_jclazz, "getThreads", "()I"); diff --git a/lite/api/android/jni/native/tensor_jni.cc b/lite/api/android/jni/native/tensor_jni.cc index 59cafa19399c4d265915e2dac8653e9ed7d10851..5212fe9a6eba2b034883da93c9ea5d845a63c773 100644 --- a/lite/api/android/jni/native/tensor_jni.cc +++ b/lite/api/android/jni/native/tensor_jni.cc @@ -120,6 +120,22 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( return JNI_TRUE; } +JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I( + JNIEnv *env, jobject jtensor, jintArray buf) { + std::unique_ptr *tensor = get_writable_tensor_pointer(env, jtensor); + if (tensor == nullptr || (*tensor == nullptr)) { + return JNI_FALSE; + } + int64_t buf_size = (int64_t)env->GetArrayLength(buf); + if (buf_size != product((*tensor)->shape())) { + return JNI_FALSE; + } + + int32_t *input = (*tensor)->mutable_data(); + env->GetIntArrayRegion(buf, 0, buf_size, input); + return JNI_TRUE; +} + JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) { if (is_const_tensor(env, jtensor)) { @@ -148,6 +164,20 @@ Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) { } } +JNIEXPORT jintArray JNICALL +Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *env, jobject jtensor) { + if (is_const_tensor(env, jtensor)) { + std::unique_ptr *tensor = + get_read_only_tensor_pointer(env, jtensor); + return cpp_array_to_jintarray( + env, (*tensor)->data(), product((*tensor)->shape())); + } else { + std::unique_ptr *tensor = get_writable_tensor_pointer(env, jtensor); + return cpp_array_to_jintarray( + env, (*tensor)->data(), product((*tensor)->shape())); + } +} + JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor( JNIEnv *env, jobject jtensor, jlong java_pointer) { if (java_pointer == 0) { diff --git a/lite/api/android/jni/native/tensor_jni.h b/lite/api/android/jni/native/tensor_jni.h index 34c35b6a76f777895dbe88dc5eadf48c659ee544..9b029dfb4c7431354d5de20c6132236764c6cc66 100644 --- 
a/lite/api/android/jni/native/tensor_jni.h +++ b/lite/api/android/jni/native/tensor_jni.h @@ -16,8 +16,8 @@ #include /* Header for class com_baidu_paddle_lite_Tensor */ -#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ -#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ +#ifndef LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ +#define LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ #ifdef __cplusplus extern "C" { #endif @@ -49,6 +49,14 @@ Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject); JNIEXPORT jbyteArray JNICALL Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject); +/* + * Class: com_baidu_paddle_lite_Tensor + * Method: getIntData + * Signature: ()[I + */ +JNIEXPORT jintArray JNICALL +Java_com_baidu_paddle_lite_Tensor_getIntData(JNIEnv *, jobject); + /* * Class: com_baidu_paddle_lite_Tensor * Method: nativeResize @@ -73,6 +81,14 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F( JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B( JNIEnv *, jobject, jbyteArray); +/* + * Class: com_baidu_paddle_lite_Tensor + * Method: nativeSetData + * Signature: ([I)Z + */ +JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3I( + JNIEnv *, jobject, jintArray); + /* * Class: com_baidu_paddle_lite_Tensor * Method: deleteCppTensor @@ -87,4 +103,4 @@ Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong); #ifdef __cplusplus } #endif -#endif // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ +#endif // LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_ diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java b/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java index 5c71db0c92b344e44ea2927305580de1be293f75..e150f98f22113ef6bcedd5e9882e0bd2a6378c97 100644 --- a/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java +++ b/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java @@ -64,6 +64,44 @@ public class MobileConfig extends ConfigBase { return powerMode.value(); } + /** + * Set model from file. + * + * @return + */ + public void setModelFromFile(String modelFile) { + this.liteModelFile = modelFile; + } + + /** + * Returns name of model_file. + * + * @return liteModelFile + */ + public String getModelFile() { + return liteModelFile; + } + + /** + * Set model from buffer. + * + * @return + */ + public void setModelFromBuffer(String modelBuffer) { + this.liteModelBuffer = modelBuffer; + } + + /** + * Returns model buffer + * + * @return liteModelBuffer + */ + public String getModelBuffer() { + return liteModelBuffer; + } + private PowerMode powerMode = PowerMode.LITE_POWER_HIGH; private int threads = 1; + private String liteModelFile; + private String liteModelBuffer; } diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java index ac78800bd2e4903b44332a0a0aefe9c69b75abab..f76841dd413ddda86678eecf8241068dd98b74a4 100644 --- a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java +++ b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java @@ -108,6 +108,19 @@ public class Tensor { return nativeSetData(buf); } + /** + * Set the tensor int data. + * + * @param buf the int array buffer which will be copied into tensor. + * @return true if set data successfully. + */ + public boolean setData(int[] buf) { + if (readOnly) { + return false; + } + return nativeSetData(buf); + } + /** * @return shape of the tensor as long array. 
      */
@@ -123,12 +136,19 @@ public class Tensor {
      */
     public native byte[] getByteData();
 
+    /**
+     * @return the tensor data as int array.
+     */
+    public native int[] getIntData();
+
     private native boolean nativeResize(long[] dims);
 
     private native boolean nativeSetData(float[] buf);
 
     private native boolean nativeSetData(byte[] buf);
 
+    private native boolean nativeSetData(int[] buf);
+
    /**
     * Delete the C++ Tensor object pointed to by the input pointer, which is
     * presented as a long value.
diff --git a/lite/api/apis_test.cc b/lite/api/apis_test.cc
index ac2c385d53ea0a1785393cd488d115d20c4264f1..bb852297d11a8862460ed6f12e007d727aca9428 100644
--- a/lite/api/apis_test.cc
+++ b/lite/api/apis_test.cc
@@ -62,7 +62,7 @@ TEST(CXXApi_LightApi, optim_model) {
 
 TEST(CXXApi_LightApi, save_and_load_model) {
   lite::Predictor cxx_api;
-  lite::LightPredictor light_api(FLAGS_optimized_model);
+  lite::LightPredictor light_api(FLAGS_optimized_model + ".nb", false);
 
   // CXXAPi
   {
diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc
index 462a5e2381acf3cc86ca81002a282933f01ee049..718dbe44296f2d197efc5b567cf0cc211835d176 100644
--- a/lite/api/benchmark.cc
+++ b/lite/api/benchmark.cc
@@ -13,40 +13,82 @@
 // limitations under the License.
 
 #include
+#include
+#include
+#include
 #include
 #include
+#include
+#include
 #include
 #include
 #include "lite/api/paddle_api.h"
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
-#include "lite/api/test_helper.h"
 #include "lite/core/device_info.h"
 #include "lite/utils/cp_logging.h"
 #include "lite/utils/string.h"
 
+DEFINE_string(model_dir,
+              "",
+              "the path of the model; set model_dir when the model is in "
+              "non-combined format. This option will be ignored if model_file "
+              "and param_file exist.");
+DEFINE_string(model_file,
+              "",
+              "the path of the model file; set model_file when the model is "
+              "in combined format.");
+DEFINE_string(param_file,
+              "",
+              "the path of the param file; set param_file when the model is "
+              "in combined format.");
 DEFINE_string(input_shape,
               "1,3,224,224",
-              "input shapes, separated by colon and comma");
-DEFINE_string(result_filename, "", "save test result");
+              "set input shapes according to the model, "
+              "separated by colon and comma, "
+              "such as 1,3,224,224:1,3,300,300.");
+DEFINE_int32(warmup, 0, "warmup times");
+DEFINE_int32(repeats, 1, "repeat times");
+DEFINE_int32(power_mode,
+             3,
+             "arm power mode: "
+             "0 for big cluster, "
+             "1 for little cluster, "
+             "2 for all cores, "
+             "3 for no bind");
+DEFINE_int32(threads, 1, "threads num");
+DEFINE_string(result_filename,
+              "result.txt",
+              "save benchmark "
+              "result to the file");
 DEFINE_bool(run_model_optimize,
             false,
-            "if set true, apply model_optimize_tool to model, use optimized "
-            "model to test");
-DEFINE_bool(is_quantized_model, false, "if set true, test the quantized model");
+            "if set true, apply model_optimize_tool to the "
+            "model and use the optimized model to test.");
+DEFINE_bool(is_quantized_model,
+            false,
+            "if set true, "
+            "test the performance of the quantized model.");
 
 namespace paddle {
 namespace lite_api {
 
-void OutputOptModel(const std::string& load_model_dir,
-                    const std::string& save_optimized_model_dir,
+inline double GetCurrentUS() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+6 * time.tv_sec + time.tv_usec;
+}
+
+void OutputOptModel(const std::string& save_optimized_model_dir,
                     const std::vector<std::vector<int64_t>>& input_shapes) {
   lite_api::CxxConfig config;
-  config.set_model_dir(load_model_dir);
-  std::vector<Place> vaild_places = {Place{TARGET(kARM), PRECISION(kFloat)},
-                                     Place{TARGET(kX86), PRECISION(kFloat)},
-                                     Place{TARGET(kOpenCL), PRECISION(kFloat)}};
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_model_file(FLAGS_model_file);
+  config.set_param_file(FLAGS_param_file);
+  std::vector<Place> vaild_places = {
+      Place{TARGET(kARM), PRECISION(kFloat)},
+  };
   if (FLAGS_is_quantized_model) {
     vaild_places.insert(vaild_places.begin(),
                         Place{TARGET(kARM), PRECISION(kInt8)});
@@ -58,34 +100,33 @@ void OutputOptModel(const std::string& load_model_dir,
       paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str())
          .c_str());
   if (ret == 0) {
-    LOG(INFO) << "delete old optimized model " << save_optimized_model_dir;
+    LOG(INFO) << "Delete old optimized model " << save_optimized_model_dir;
   }
   predictor->SaveOptimizedModel(save_optimized_model_dir,
                                 LiteModelType::kNaiveBuffer);
-  LOG(INFO) << "Load model from " << load_model_dir;
+  LOG(INFO) << "Load model from " << FLAGS_model_dir;
   LOG(INFO) << "Save optimized model to " << save_optimized_model_dir;
 }
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 void Run(const std::vector<std::vector<int64_t>>& input_shapes,
          const std::string& model_dir,
-         const int repeat,
-         const int thread_num,
-         const int warmup_times,
          const std::string model_name) {
+  // set config and create predictor
   lite_api::MobileConfig config;
-  config.set_threads(thread_num);
-  config.set_power_mode(LITE_POWER_NO_BIND);
-  config.set_model_dir(model_dir);
+  config.set_threads(FLAGS_threads);
+  config.set_power_mode(static_cast<PowerMode>(FLAGS_power_mode));
+  config.set_model_from_file(model_dir + ".nb");
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
+  // set input
   for (int j = 0; j < input_shapes.size(); ++j) {
     auto input_tensor = predictor->GetInput(j);
     input_tensor->Resize(input_shapes[j]);
     auto input_data = input_tensor->mutable_data<float>();
     int input_num = 1;
-    for (int i = 0; i < input_shapes[j].size(); ++i) {
+    for (size_t i = 0; i < input_shapes[j].size(); ++i) {
       input_num *= input_shapes[j][i];
     }
     for (int i = 0; i < input_num; ++i) {
@@ -93,26 +134,37 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
     }
   }
 
-  for (int i = 0; i < warmup_times; ++i) {
+  // warmup
+  for (int i = 0; i < FLAGS_warmup; ++i) {
     predictor->Run();
   }
 
-  auto start = lite::GetCurrentUS();
-  for (int i = 0; i < repeat; ++i) {
+  // run
+  std::vector<float> perf_vct;
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    auto start = GetCurrentUS();
     predictor->Run();
+    auto end = GetCurrentUS();
+    perf_vct.push_back((end - start) / 1000.0);
   }
-  auto end = lite::GetCurrentUS();
-
-  std::FILE* pf = std::fopen(FLAGS_result_filename.c_str(), "a");
-  if (nullptr == pf) {
-    LOG(INFO) << "create result file error";
-    exit(0);
+  std::sort(perf_vct.begin(), perf_vct.end());
+  float min_res = perf_vct.front();
+  float max_res = perf_vct.back();
+  float total_res = accumulate(perf_vct.begin(), perf_vct.end(), 0.0);
+  float avg_res = total_res / FLAGS_repeats;
+
+  // save result
+  std::ofstream ofs(FLAGS_result_filename, std::ios::app);
+  if (!ofs.is_open()) {
+    LOG(FATAL) << "open result file failed";
   }
-  fprintf(pf,
-          "-- %-18s avg = %5.4f ms\n",
-          model_name.c_str(),
-          (end - start) / repeat / 1000.0);
-  std::fclose(pf);
+  ofs.precision(5);
+  ofs << std::setw(30) << std::fixed << std::left << model_name;
+  ofs << "min = " << std::setw(12) << min_res;
+  ofs << "max = " << std::setw(12) << max_res;
+  ofs << "average = " << std::setw(12) << avg_res;
+  ofs << std::endl;
+  ofs.close();
 }
 #endif
@@ -122,9 +174,7 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
   if (FLAGS_model_dir == "" || FLAGS_result_filename == "") {
-    LOG(INFO) << "usage: "
-              << "--model_dir /path/to/your/model --result_filename "
-                 "/path/to/resultfile";
+    LOG(INFO) << "please run ./benchmark_bin --help to obtain usage.";
     exit(0);
   }
 
@@ -166,26 +216,20 @@
   std::vector<std::string> str_input_shapes = split_string(FLAGS_input_shape);
   std::vector<std::vector<int64_t>> input_shapes;
-  for (int i = 0; i < str_input_shapes.size(); ++i) {
+  for (size_t i = 0; i < str_input_shapes.size(); ++i) {
     input_shapes.push_back(get_shape(str_input_shapes[i]));
   }
 
-  // Output optimized model
+  // Output optimized model if needed
   if (FLAGS_run_model_optimize) {
-    paddle::lite_api::OutputOptModel(
-        FLAGS_model_dir, save_optimized_model_dir, input_shapes);
+    paddle::lite_api::OutputOptModel(save_optimized_model_dir, input_shapes);
   }
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
   // Run inference using optimized model
   std::string run_model_dir =
       FLAGS_run_model_optimize ? save_optimized_model_dir : FLAGS_model_dir;
-  paddle::lite_api::Run(input_shapes,
-                        run_model_dir,
-                        FLAGS_repeats,
-                        FLAGS_threads,
-                        FLAGS_warmup,
-                        model_name);
+  paddle::lite_api::Run(input_shapes, run_model_dir, model_name);
 #endif
   return 0;
 }
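Editor's note: the rewritten benchmark above times each repeat individually and derives its summary statistics from the sorted latency vector, so after an ascending sort `front()` is the minimum and `back()` the maximum. A self-contained sketch of that reporting pattern, with invented sample latencies:

```cpp
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-repeat latencies in milliseconds, one entry per run.
  std::vector<float> perf_vct = {12.4f, 11.8f, 13.1f, 12.0f};
  std::sort(perf_vct.begin(), perf_vct.end());
  float min_res = perf_vct.front();  // smallest value after ascending sort
  float max_res = perf_vct.back();   // largest value
  float avg_res =
      std::accumulate(perf_vct.begin(), perf_vct.end(), 0.0f) /
      static_cast<float>(perf_vct.size());
  std::cout << "min = " << min_res << ", max = " << max_res
            << ", average = " << avg_res << std::endl;
  return 0;
}
```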
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index cbe938cea6e5f84dfb3718585da0880e16cd5bfc..f6f7ec75e65ff54e3f3642822e51057d3522ae3a 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -24,13 +24,6 @@
 namespace paddle {
 namespace lite {
 
-static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] =
-    ".tailored_ops_source_list";
-static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list";
-static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
-    ".tailored_kernels_source_list";
-static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
-
 void Predictor::SaveModel(const std::string &dir,
                           lite_api::LiteModelType model_type,
                           bool record_info) {
@@ -50,6 +43,7 @@ void Predictor::SaveModel(const std::string &dir,
     LOG(FATAL) << "Unknown model type";
   }
   if (record_info) {
+    MkDirRecur(dir);
     SaveOpKernelInfo(dir);
   }
 }
@@ -128,6 +122,7 @@ void Predictor::SaveOpKernelInfo(const std::string &model_dir) {
             << kpf_path;
 }
 
+#ifndef LITE_WITH_FPGA
 lite::Tensor *Predictor::GetInput(size_t offset) {
   CHECK(input_names_.size() > offset)
       << "The network has " << input_names_.size() << " inputs"
@@ -137,6 +132,17 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
       << " in exec_scope";
   return in_var->GetMutable<lite::Tensor>();
 }
+#else
+lite::Tensor *Predictor::GetInput(size_t offset) {
+  auto *_feed_list = exec_scope_->FindVar("feed");
+  CHECK(_feed_list) << "no feed variable in exec_scope";
+  auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
+  if (offset >= feed_list->size()) {
+    feed_list->resize(offset + 1);
+  }
+  return &feed_list->at(offset);
+}
+#endif
 
 // get inputs names
 std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
@@ -149,10 +155,10 @@ void Predictor::PrepareFeedFetch() {
   if (!program_) {
     GenRuntimeProgram();
   }
+
   std::vector<const cpp::OpDesc *> feeds;
   std::vector<const cpp::OpDesc *> fetchs;
   const auto &insts = program_->instructions();
-
   for (size_t i = 0; i < program_->num_instructions(); i++) {
     const auto &op = insts[i].op()->op_info();
     if (op->Type() == "feed") {
@@ -174,6 +180,8 @@
     }
   }
 
+#ifndef LITE_WITH_FPGA
+
 const lite::Tensor *Predictor::GetOutput(size_t offset) const {
   CHECK(output_names_.size() > offset)
       << "The network has " << output_names_.size() << " outputs"
@@ -193,6 +201,29 @@
   }
   return outputs;
 }
+#else
+
+const lite::Tensor *Predictor::GetOutput(size_t offset) const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
+  return &fetch_list.at(offset);
+}
+
+std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+
+  std::vector<const lite::Tensor *> outputs;
+  for (auto &out : fetch_list) {
+    outputs.push_back(&out);
+  }
+  return outputs;
+}
+
+#endif
 
 const cpp::ProgramDesc &Predictor::program_desc() const {
   return program_desc_;
@@ -208,7 +239,11 @@ void Predictor::Build(const lite_api::CxxConfig &config,
   const std::string &model_file = config.model_file();
   const std::string &param_file = config.param_file();
   const bool model_from_memory = config.model_from_memory();
-  LOG(INFO) << "load from memory " << model_from_memory;
+  if (model_from_memory) {
+    LOG(INFO) << "Load model from memory.";
+  } else {
+    LOG(INFO) << "Load model from file.";
+  }
 
   Build(model_path,
         model_file,
@@ -242,7 +277,7 @@
     case lite_api::LiteModelType::kNaiveBuffer:
       CHECK(!model_path.empty())
          << "NaiveBuffer backend only supported combined param";
-      LoadModelNaive(model_path, scope_.get(), &program_desc_);
+      LoadModelNaiveFromFile(model_path, scope_.get(), &program_desc_);
       break;
     default:
       LOG(FATAL) << "Unknown model type";
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 502ce812e1f4a7f520e89e6eaff020c5853f5308..504710d9fa29420b8762f31e0c675b59c6c626bd 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -29,6 +29,13 @@
 namespace paddle {
 namespace lite {
 
+static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] =
+    ".tailored_ops_source_list";
+static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list";
+static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
+    ".tailored_kernels_source_list";
+static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
+
 /*
  * Predictor for inference, input a model, it will optimize and execute it.
*/ diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 6fa400db6da9f029c38b496cd70d593a876628c9..81ea60eac66849f8ce42fb8cb210226d18bbfa9b 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -20,6 +20,12 @@ #include "lite/core/device_info.h" #include "lite/core/version.h" +#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ + !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) +#include +#include "lite/backends/x86/mklml.h" +#endif + namespace paddle { namespace lite { @@ -33,6 +39,17 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { mode_ = config.power_mode(); threads_ = config.threads(); + +#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ + !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) + int num_threads = config.x86_math_library_num_threads(); + int real_num_threads = num_threads > 1 ? num_threads : 1; + paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads); + omp_set_num_threads(real_num_threads); + VLOG(3) << "set_x86_math_library_math_threads() is set successfully and the " + "number of threads is:" + << num_threads; +#endif } std::unique_ptr CxxPaddleApiImpl::GetInput(int i) { diff --git a/lite/api/cxx_api_test.cc b/lite/api/cxx_api_test.cc index 4d711302cb5880247f4a7b7082185c500b9ad6e9..cdf1e838366f4bcafc1c1c991d8805f115de7345 100644 --- a/lite/api/cxx_api_test.cc +++ b/lite/api/cxx_api_test.cc @@ -101,7 +101,7 @@ TEST(CXXApi, save_model) { TEST(CXXApi, load_model_naive) { lite::Predictor predictor; std::vector valid_places({Place{TARGET(kARM), PRECISION(kFloat)}}); - predictor.Build(FLAGS_optimized_model + ".naive", + predictor.Build(FLAGS_optimized_model + ".naive.nb", "", "", valid_places, diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index a0c4b7e5e375d9d004de63345ba5013ee6c252b9..29d8f4f29ab822f8c9601bbd63a3626abbbf1818 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -18,6 +18,17 @@ namespace paddle { namespace lite { +void LightPredictor::Build(const std::string& lite_model_file, + bool model_from_memory) { + if (model_from_memory) { + LoadModelNaiveFromMemory(lite_model_file, scope_.get(), &cpp_program_desc_); + } else { + LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_); + } + BuildRuntimeProgram(cpp_program_desc_); + PrepareFeedFetch(); +} + void LightPredictor::Build(const std::string& model_dir, const std::string& model_buffer, const std::string& param_buffer, @@ -41,6 +52,8 @@ void LightPredictor::Build(const std::string& model_dir, default: LOG(FATAL) << "Unknown model type"; } + + DequantizeWeight(); BuildRuntimeProgram(cpp_program_desc_); PrepareFeedFetch(); } @@ -144,5 +157,69 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { program_->set_exec_scope(program.exec_scope()); } +void LightPredictor::DequantizeWeight() { +#define PROCESS_CONV2D_DATA() \ + for (int64_t i = 0; i < h; ++i) { \ + for (int64_t j = 0; j < w; ++j) { \ + fp_data[i * w + j] = scale_list[i] * int_data[i * w + j]; \ + } \ + } + +#define PROCESS_FC_DATA() \ + for (int i = 0; i < input_tensor->numel(); i++) { \ + *fp_data = scale_list[0] * (*int_data); \ + ++fp_data; \ + ++int_data; \ + } + + Tensor tmp_tensor; + CHECK(cpp_program_desc_.BlocksSize()); + auto* main_block = cpp_program_desc_.GetBlock(0); + for (size_t k = 0; k < main_block->OpsSize(); ++k) { + auto* op_desc = main_block->GetOp(k); + if (op_desc->HasAttr("quantize_weight_bits")) { // weight quantized op + auto input_names = op_desc->input_vars(); + for (auto& input_name : input_names) { + 
        std::string input_scale_name = input_name + "_quant_scale";
+        if (op_desc->HasAttr(input_scale_name)) {  // the input is quantized
+          auto input_tensor =
+              scope_->FindVar(input_name)->GetMutable<Tensor>();
+          tmp_tensor.CopyDataFrom(*input_tensor);
+          auto scale_list =
+              op_desc->GetAttr<std::vector<float>>(input_scale_name);
+          int quantize_weight_bits =
+              op_desc->GetAttr<int>("quantize_weight_bits");
+          float* fp_data = input_tensor->mutable_data<float>();
+
+          std::string op_type = op_desc->Type();
+          if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
+            int64_t h = input_tensor->dims()[0];
+            int64_t w = input_tensor->numel() / h;
+            CHECK_EQ(scale_list.size(), h);
+            if (quantize_weight_bits == 8) {
+              const int8_t* int_data = tmp_tensor.data<int8_t>();
+              PROCESS_CONV2D_DATA()
+            } else {
+              const int16_t* int_data = tmp_tensor.data<int16_t>();
+              PROCESS_CONV2D_DATA()
+            }
+          } else if (op_type == "fc" || op_type == "mul") {
+            if (quantize_weight_bits == 8) {
+              const int8_t* int_data = tmp_tensor.data<int8_t>();
+              PROCESS_FC_DATA()
+            } else {
+              const int16_t* int_data = tmp_tensor.data<int16_t>();
+              PROCESS_FC_DATA()
+            }
+          }
+        }
+      }
+    }
+  }
+
+#undef PROCESS_CONV2D_DATA
+#undef PROCESS_FC_DATA
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/api/light_api.h b/lite/api/light_api.h
index 3781bc4d674db5d2e8794edaf33f00627b9977bb..aa25ea81c7b62238211f96265a4edc49f2d065a1 100644
--- a/lite/api/light_api.h
+++ b/lite/api/light_api.h
@@ -18,6 +18,7 @@
  */
 
 #pragma once
+#include
 #include
 #include
 #include
@@ -39,12 +40,22 @@ namespace lite {
  */
 class LITE_API LightPredictor {
  public:
-  LightPredictor(
-      const std::string& model_dir,
-      const std::string& model_buffer = "",
-      const std::string& param_buffer = "",
-      bool model_from_memory = false,
-      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf) {
+  // Constructor of LightPredictor: `lite_model_file` refers to the data in the
+  // model file or buffer, `model_from_memory` refers to whether to load the
+  // model from memory.
+  LightPredictor(const std::string& lite_model_file,
+                 bool model_from_memory = false) {
+    scope_ = std::make_shared<Scope>();
+    Build(lite_model_file, model_from_memory);
+  }
+
+  // NOTE: This is a deprecated API and will be removed in a later release.
+  LightPredictor(const std::string& model_dir,
+                 const std::string& model_buffer = "",
+                 const std::string& param_buffer = "",
+                 bool model_from_memory = false,
+                 lite_api::LiteModelType model_type =
+                     lite_api::LiteModelType::kNaiveBuffer) {
     scope_ = std::make_shared<Scope>();
     Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory);
   }
@@ -69,6 +80,10 @@ class LITE_API LightPredictor {
   void PrepareFeedFetch();
 
  private:
+  void Build(const std::string& lite_model_file,
+             bool model_from_memory = false);
+
+  // NOTE: This is a deprecated API and will be removed in a later release.
void Build( const std::string& model_dir, const std::string& model_buffer, @@ -78,6 +93,8 @@ class LITE_API LightPredictor { void BuildRuntimeProgram(const cpp::ProgramDesc& prog); + void DequantizeWeight(); + private: std::shared_ptr scope_; std::unique_ptr program_; diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index a0ae28df0958403237114a3d4b94031829019339..3965843250abe45c43490bdbb4aaed58915e0908 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -23,13 +23,17 @@ namespace lite { void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { // LightPredictor Only support NaiveBuffer backend in publish lib - raw_predictor_.reset( - new LightPredictor(config.model_dir(), - config.model_buffer(), - config.param_buffer(), - config.model_from_memory(), - lite_api::LiteModelType::kNaiveBuffer)); - + if (config.lite_model_file().empty()) { + raw_predictor_.reset( + new LightPredictor(config.model_dir(), + config.model_buffer(), + config.param_buffer(), + config.model_from_memory(), + lite_api::LiteModelType::kNaiveBuffer)); + } else { + raw_predictor_.reset(new LightPredictor(config.lite_model_file(), + config.model_from_memory())); + } mode_ = config.power_mode(); threads_ = config.threads(); } diff --git a/lite/api/lite_multithread_test.cc b/lite/api/lite_multithread_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..addd512eb0039c43edeca562b8f568528aab76f9 --- /dev/null +++ b/lite/api/lite_multithread_test.cc @@ -0,0 +1,360 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/device_info.h" +#include "lite/core/profile/timer.h" +#include "lite/utils/cp_logging.h" +#include "lite/utils/string.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/basic_profiler.h" +#endif // LITE_WITH_PROFILE +#include // NOLINT + +using paddle::lite::profile::Timer; + +DEFINE_string(input_shape, + "1,3,224,224", + "input shapes, separated by colon and comma"); + +DEFINE_string(model_dir_0, "", "model_dir_0"); +DEFINE_string(input_shape_0, + "1,3,224,224", + "input shapes another, separated by colon and comma"); + +DEFINE_bool(use_optimize_nb, + false, + "optimized & naive buffer model for mobile devices"); + +DEFINE_int32(test_type, 0, "multithread test type"); + +namespace paddle { +namespace lite_api { + +void OutputOptModel(const std::string& load_model_dir, + const std::string& save_optimized_model_dir, + const std::vector>& input_shapes) { + lite_api::CxxConfig config; + config.set_model_dir(load_model_dir); + config.set_valid_places({ + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + auto predictor = lite_api::CreatePaddlePredictor(config); + + // delete old optimized model + int ret = system( + paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str()) + .c_str()); + if (ret == 0) { + LOG(INFO) << "delete old optimized model " << save_optimized_model_dir; + } + predictor->SaveOptimizedModel(save_optimized_model_dir, + LiteModelType::kNaiveBuffer); + LOG(INFO) << "Load model from " << load_model_dir; + LOG(INFO) << "Save optimized model to " << save_optimized_model_dir; +} + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +void Run(const std::vector>& input_shapes, + const std::string& model_dir, + const PowerMode power_mode, + const int thread_num, + const int repeat, + int tid, + const int warmup_times = 5) { + lite_api::MobileConfig config; + config.set_model_dir(model_dir); + config.set_power_mode(power_mode); + config.set_threads(thread_num); + + auto predictor = lite_api::CreatePaddlePredictor(config); + + for (int j = 0; j < input_shapes.size(); ++j) { + auto input_tensor = predictor->GetInput(j); + input_tensor->Resize(input_shapes[j]); + auto input_data = input_tensor->mutable_data(); + int input_num = 1; + for (int i = 0; i < input_shapes[j].size(); ++i) { + input_num *= input_shapes[j][i]; + } + for (int i = 0; i < input_num; ++i) { + input_data[i] = 1.f; + } + } + + for (int i = 0; i < warmup_times; ++i) { + predictor->Run(); + } + + Timer ti; + for (int j = 0; j < repeat; ++j) { + ti.Start(); + predictor->Run(); + float t = ti.Stop(); + auto output = predictor->GetOutput(0); + auto out = output->data(); + LOG(INFO) << "[thread " << tid << "] Model: " << model_dir + << " output[0]:" << out[0] << "; output[1]:" << out[1]; + } + LOG(INFO) << "[thread " << tid << "] Model: " << model_dir + << ", power_mode: " << static_cast(power_mode) + << ", threads num " << thread_num + << ", avg time: " << ti.LapTimes().Avg() << "ms" + << ", min time: " << ti.LapTimes().Min() << " ms" + << ", max time: " << ti.LapTimes().Max() << " ms."; +} + +void RunTestType_00(const std::vector>& input_shapes, + const std::string& model_dir, + const PowerMode power_mode, + const int thread_num, + const int repeat, + const int warmup_times = 5) { + std::thread run_th0(Run, + input_shapes, + model_dir, + power_mode, + thread_num, + repeat, + 
0, + warmup_times); + Run(input_shapes, model_dir, power_mode, thread_num, repeat, 1, warmup_times); + run_th0.join(); +} +void RunTestType_01(const std::vector>& input_shapes, + const std::string& model_dir, + const std::vector>& input_shapes_0, + const std::string& model_dir_0, + const PowerMode power_mode, + const int thread_num, + const int repeat, + const int warmup_times = 5) { + std::thread run_th0(Run, + input_shapes, + model_dir, + power_mode, + thread_num, + repeat, + 0, + warmup_times); + Run(input_shapes_0, + model_dir_0, + power_mode, + thread_num, + repeat, + 1, + warmup_times); + run_th0.join(); +} + +void run_with_predictor(std::shared_ptr predictor, + const std::vector>& input_shapes, + int index, + const std::string& name) { + for (int j = 0; j < input_shapes.size(); ++j) { + auto input_tensor = predictor->GetInput(j); + input_tensor->Resize(input_shapes[j]); + auto input_data = input_tensor->mutable_data(); + int input_num = 1; + for (int i = 0; i < input_shapes[j].size(); ++i) { + input_num *= input_shapes[j][i]; + } + for (int i = 0; i < input_num; ++i) { + input_data[i] = 1.f; + } + } + + Timer ti; + ti.Start(); + predictor->Run(); + float t = ti.Stop(); + + auto output = predictor->GetOutput(0); + auto out = output->data(); + LOG(INFO) << "[thread " << index << "] name: " << name + << ",run time: " << ti.LapTimes().Avg() << "ms" + << " output[0]:" << out[0] << "; output[1]:" << out[1]; +} +void RunTestType_10(const std::vector>& input_shapes, + const std::string& model_dir, + const PowerMode power_mode, + const int thread_num, + const int repeat, + int warmup = 5) { + lite_api::MobileConfig config; + config.set_model_dir(model_dir); + config.set_power_mode(power_mode); + config.set_threads(thread_num); + + auto predictor = lite_api::CreatePaddlePredictor(config); + + for (int i = 0; i < repeat; ++i) { + std::thread pre_th0( + run_with_predictor, predictor, input_shapes, i, model_dir); + pre_th0.join(); + } +} +void RunTestType_11(const std::vector>& input_shapes, + const std::string& model_dir, + const std::vector>& input_shapes_0, + const std::string& model_dir_0, + const PowerMode power_mode, + const int thread_num, + const int repeat, + int warmup = 5) { + lite_api::MobileConfig config; + config.set_model_dir(model_dir); + config.set_power_mode(power_mode); + config.set_threads(thread_num); + + auto predictor = lite_api::CreatePaddlePredictor(config); + + config.set_model_dir(model_dir_0); + auto predictor_0 = lite_api::CreatePaddlePredictor(config); + + for (int i = 0; i < 2 * repeat; i += 2) { + std::thread pre_th0( + run_with_predictor, predictor, input_shapes, i, model_dir); + std::thread pre_th1( + run_with_predictor, predictor_0, input_shapes_0, i + 1, model_dir_0); + pre_th0.join(); + pre_th1.join(); + } +} + +#endif + +} // namespace lite_api +} // namespace paddle + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (FLAGS_model_dir == "") { + LOG(INFO) << "usage: " + << "--model_dir /path/to/your/model"; + exit(0); + } + std::string save_optimized_model_dir = ""; + std::string save_optimized_model_dir_0 = ""; + if (FLAGS_use_optimize_nb) { + save_optimized_model_dir = FLAGS_model_dir; + save_optimized_model_dir_0 = FLAGS_model_dir_0; + } else { + save_optimized_model_dir = FLAGS_model_dir + "opt2"; + save_optimized_model_dir_0 = FLAGS_model_dir_0 + "opt2"; + } + + auto split_string = + [](const std::string& str_in) -> std::vector { + std::vector str_out; + std::string tmp_str = str_in; + while (!tmp_str.empty()) 
{ + size_t next_offset = tmp_str.find(":"); + str_out.push_back(tmp_str.substr(0, next_offset)); + if (next_offset == std::string::npos) { + break; + } else { + tmp_str = tmp_str.substr(next_offset + 1); + } + } + return str_out; + }; + + auto get_shape = [](const std::string& str_shape) -> std::vector { + std::vector shape; + std::string tmp_str = str_shape; + while (!tmp_str.empty()) { + int dim = atoi(tmp_str.data()); + shape.push_back(dim); + size_t next_offset = tmp_str.find(","); + if (next_offset == std::string::npos) { + break; + } else { + tmp_str = tmp_str.substr(next_offset + 1); + } + } + return shape; + }; + + std::vector str_input_shapes = split_string(FLAGS_input_shape); + std::vector> input_shapes; + for (int i = 0; i < str_input_shapes.size(); ++i) { + input_shapes.push_back(get_shape(str_input_shapes[i])); + } + std::vector str_input_shapes_0 = + split_string(FLAGS_input_shape_0); + std::vector> input_shapes_0; + for (int i = 0; i < str_input_shapes_0.size(); ++i) { + input_shapes_0.push_back(get_shape(str_input_shapes_0[i])); + } + + if (!FLAGS_use_optimize_nb) { + // Output optimized model + paddle::lite_api::OutputOptModel( + FLAGS_model_dir, save_optimized_model_dir, input_shapes); + paddle::lite_api::OutputOptModel( + FLAGS_model_dir_0, save_optimized_model_dir_0, input_shapes_0); + } + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + // Run inference using optimized model + if (FLAGS_test_type == 0) { + paddle::lite_api::RunTestType_00( + input_shapes, + save_optimized_model_dir, + static_cast(0), + FLAGS_threads, + FLAGS_repeats, + 5); + LOG(INFO) << "=========above is case 0, below is case " + "1============================"; + paddle::lite_api::RunTestType_10( + input_shapes, + save_optimized_model_dir, + static_cast(0), + FLAGS_threads, + FLAGS_repeats); + } + if (FLAGS_test_type == 1) { + paddle::lite_api::RunTestType_01( + input_shapes, + save_optimized_model_dir, + input_shapes_0, + save_optimized_model_dir_0, + static_cast(0), + FLAGS_threads, + FLAGS_repeats, + 5); + LOG(INFO) << "=========above is case 0, below is case " + "1============================"; + paddle::lite_api::RunTestType_11( + input_shapes, + save_optimized_model_dir, + input_shapes_0, + save_optimized_model_dir_0, + static_cast(0), + FLAGS_threads, + FLAGS_repeats); + } + +#endif + return 0; +} diff --git a/lite/api/mobilenetv1_test.cc b/lite/api/mobilenetv1_test.cc index 79f9bea762e099b249f597dddb7df790361edc2a..bcc9644f81542ab6fb8a0badf8ecaea89fc8dedb 100644 --- a/lite/api/mobilenetv1_test.cc +++ b/lite/api/mobilenetv1_test.cc @@ -23,6 +23,10 @@ #include "lite/core/op_registry.h" DEFINE_string(optimized_model, "", "optimized_model"); +DEFINE_int32(N, 1, "input_batch"); +DEFINE_int32(C, 3, "input_channel"); +DEFINE_int32(H, 224, "input_height"); +DEFINE_int32(W, 224, "input_width"); namespace paddle { namespace lite { @@ -37,7 +41,8 @@ void TestModel(const std::vector& valid_places, predictor.Build(model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); - input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + input_tensor->Resize(DDim( + std::vector({FLAGS_N, FLAGS_C, FLAGS_H, FLAGS_W}))); auto* data = input_tensor->mutable_data(); auto item_size = input_tensor->dims().production(); for (int i = 0; i < item_size; i++) { @@ -58,6 +63,8 @@ void TestModel(const std::vector& valid_places, predictor.SaveModel(FLAGS_optimized_model); } + LOG(INFO) << "input shape(NCHW):" << FLAGS_N << " " << FLAGS_C << " " + << FLAGS_H << " " << FLAGS_W; LOG(INFO) << "================== 
Speed Report ==================="; LOG(INFO) << "Model: " << model_dir << ", threads num " << FLAGS_threads << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats @@ -123,10 +130,10 @@ TEST(MobileNetV1, test_arm) { #ifdef LITE_WITH_OPENCL TEST(MobileNetV1, test_opencl) { std::vector valid_places({ - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)}, - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)}, + Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)}, Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}, - Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)}, + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)}, + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)}, TARGET(kARM), // enable kARM CPU kernel when no opencl kernel }); diff --git a/lite/api/mobilenetv2_test.cc b/lite/api/mobilenetv2_test.cc index 84bd27e352f549d619cfa51f9127f973023e6d45..012d6d48d9e6d3747f83a7f1089944bbaf359f71 100644 --- a/lite/api/mobilenetv2_test.cc +++ b/lite/api/mobilenetv2_test.cc @@ -23,6 +23,10 @@ #include "lite/core/op_registry.h" DEFINE_string(optimized_model, "", "optimized_model"); +DEFINE_int32(N, 1, "input_batch"); +DEFINE_int32(C, 3, "input_channel"); +DEFINE_int32(H, 224, "input_height"); +DEFINE_int32(W, 224, "input_width"); namespace paddle { namespace lite { @@ -38,7 +42,8 @@ void TestModel(const std::vector& valid_places, predictor.Build(model_dir, "", "", valid_places); auto* input_tensor = predictor.GetInput(0); - input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + input_tensor->Resize(DDim( + std::vector({FLAGS_N, FLAGS_C, FLAGS_H, FLAGS_W}))); auto* data = input_tensor->mutable_data(); auto item_size = input_tensor->dims().production(); for (int i = 0; i < item_size; i++) { @@ -59,6 +64,8 @@ void TestModel(const std::vector& valid_places, predictor.SaveModel(FLAGS_optimized_model); } + LOG(INFO) << "input shape(NCHW):" << FLAGS_N << " " << FLAGS_C << " " + << FLAGS_H << " " << FLAGS_W; LOG(INFO) << "================== Speed Report ==================="; LOG(INFO) << "Model: " << model_dir << ", threads num " << FLAGS_threads << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats @@ -123,8 +130,11 @@ TEST(MobileNetV2, test_arm) { #ifdef LITE_WITH_OPENCL TEST(MobileNetV2, test_opencl) { std::vector valid_places({ - Place{TARGET(kOpenCL), PRECISION(kFloat)}, - Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)}, + Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)}, + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)}, + TARGET(kARM), // enable kARM CPU kernel when no opencl kernel }); TestModel(valid_places); diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc deleted file mode 100644 index daa57cd45632764172426cc41914abc7f82bea33..0000000000000000000000000000000000000000 --- a/lite/api/model_optimize_tool.cc +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#ifdef PADDLE_WITH_TESTING -#include -#endif -// "all_kernel_faked.cc" and "kernel_src_map.h" are created automatically during -// model_optimize_tool's compiling period -#include "all_kernel_faked.cc" // NOLINT -#include "kernel_src_map.h" // NOLINT -#include "lite/api/paddle_api.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/api/paddle_use_passes.h" -#include "lite/core/op_registry.h" -#include "lite/utils/cp_logging.h" -#include "lite/utils/string.h" - -DEFINE_string(model_dir, - "", - "path of the model. This option will be ignored if model_file " - "and param_file are exist"); -DEFINE_string(model_file, "", "model file path of the combined-param model"); -DEFINE_string(param_file, "", "param file path of the combined-param model"); -DEFINE_string( - optimize_out_type, - "protobuf", - "store type of the output optimized model. protobuf/naive_buffer"); -DEFINE_bool(display_kernels, false, "Display kernel information"); -DEFINE_bool(record_tailoring_info, - false, - "Record kernels and operators information of the optimized model " - "for tailoring compiling, information are stored into optimized " - "model path as hidden files"); -DEFINE_string(optimize_out, "", "path of the output optimized model"); -DEFINE_string(valid_targets, - "arm", - "The targets this model optimized for, should be one of (arm, " - "opencl, x86), splitted by space"); -DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); - -namespace paddle { -namespace lite_api { - -//! Display the kernel information. -void DisplayKernels() { - LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString(); -} - -void Main() { - if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) { - LOG(WARNING) - << "Load combined-param model. 
Option model_dir will be ignored"; - } - - if (FLAGS_display_kernels) { - DisplayKernels(); - exit(0); - } - - lite_api::CxxConfig config; - config.set_model_dir(FLAGS_model_dir); - config.set_model_file(FLAGS_model_file); - config.set_param_file(FLAGS_param_file); - - std::vector valid_places; - auto target_reprs = lite::Split(FLAGS_valid_targets, " "); - for (auto& target_repr : target_reprs) { - if (target_repr == "arm") { - valid_places.emplace_back(TARGET(kARM)); - } else if (target_repr == "opencl") { - valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)}); - valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)}); - valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}); - valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)}); - valid_places.emplace_back( - TARGET(kARM)); // enable kARM CPU kernel when no opencl kernel - } else if (target_repr == "x86") { - valid_places.emplace_back(TARGET(kX86)); - } else { - LOG(FATAL) << lite::string_format( - "Wrong target '%s' found, please check the command flag " - "'valid_targets'", - target_repr.c_str()); - } - } - - CHECK(!valid_places.empty()) - << "At least one target should be set, should set the " - "command argument 'valid_targets'"; - - if (FLAGS_prefer_int8_kernel) { - LOG(WARNING) << "Int8 mode is only support by ARM target"; - valid_places.insert(valid_places.begin(), - Place{TARGET(kARM), PRECISION(kInt8)}); - } - config.set_valid_places(valid_places); - - auto predictor = lite_api::CreatePaddlePredictor(config); - - LiteModelType model_type; - if (FLAGS_optimize_out_type == "protobuf") { - model_type = LiteModelType::kProtobuf; - } else if (FLAGS_optimize_out_type == "naive_buffer") { - model_type = LiteModelType::kNaiveBuffer; - } else { - LOG(FATAL) << "Unsupported Model type :" << FLAGS_optimize_out_type; - } - OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map); - - predictor->SaveOptimizedModel( - FLAGS_optimize_out, model_type, FLAGS_record_tailoring_info); - if (FLAGS_record_tailoring_info) { - LOG(INFO) << "Record the information of tailored model into :" - << FLAGS_optimize_out; - } -} - -} // namespace lite_api -} // namespace paddle - -int main(int argc, char** argv) { - google::ParseCommandLineFlags(&argc, &argv, false); - paddle::lite_api::Main(); - return 0; -} diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc index 1358267000991c81b80453669cf46638449b8a7b..190890da4c109f39cc52ca5209cd952f8937f780 100644 --- a/lite/api/model_test.cc +++ b/lite/api/model_test.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include #include #include #include "lite/api/paddle_api.h" @@ -21,22 +22,22 @@ #include "lite/api/paddle_use_passes.h" #include "lite/api/test_helper.h" #include "lite/core/device_info.h" -#include "lite/tests/utils/timer.h" +#include "lite/core/profile/timer.h" #include "lite/utils/cp_logging.h" #include "lite/utils/string.h" #ifdef LITE_WITH_PROFILE #include "lite/core/profile/basic_profiler.h" #endif // LITE_WITH_PROFILE -using paddle::lite::Timer; +using paddle::lite::profile::Timer; DEFINE_string(input_shape, "1,3,224,224", "input shapes, separated by colon and comma"); - DEFINE_bool(use_optimize_nb, false, "optimized & naive buffer model for mobile devices"); +DEFINE_string(arg_name, "", "the arg name"); namespace paddle { namespace lite_api { @@ -47,7 +48,6 @@ void OutputOptModel(const std::string& load_model_dir, lite_api::CxxConfig config; config.set_model_dir(load_model_dir); config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)}, }); auto predictor = lite_api::CreatePaddlePredictor(config); @@ -72,12 +72,8 @@ void Run(const std::vector>& input_shapes, const int thread_num, const int repeat, const int warmup_times = 0) { -#ifdef LITE_WITH_PROFILE - lite::profile::BasicProfiler::Global().SetWarmup( - warmup_times); -#endif lite_api::MobileConfig config; - config.set_model_dir(model_dir); + config.set_model_from_file(model_dir + ".nb"); config.set_power_mode(power_mode); config.set_threads(thread_num); @@ -91,6 +87,7 @@ void Run(const std::vector>& input_shapes, for (int i = 0; i < input_shapes[j].size(); ++i) { input_num *= input_shapes[j][i]; } + for (int i = 0; i < input_num; ++i) { input_data[i] = 1.f; } @@ -102,20 +99,20 @@ void Run(const std::vector>& input_shapes, Timer ti; for (int j = 0; j < repeat; ++j) { - ti.start(); + ti.Start(); predictor->Run(); - ti.end(); - LOG(INFO) << "iter: " << j << ", time: " << ti.latest_time() << " ms"; + float t = ti.Stop(); + LOG(INFO) << "iter: " << j << ", time: " << t << " ms"; } LOG(INFO) << "================== Speed Report ==================="; LOG(INFO) << "Model: " << model_dir << ", power_mode: " << static_cast(power_mode) << ", threads num " << thread_num << ", warmup: " << warmup_times - << ", repeats: " << repeat << ", avg time: " << ti.get_average_ms() + << ", repeats: " << repeat << ", avg time: " << ti.LapTimes().Avg() << " ms" - << ", min time: " << ti.get_min_time() << " ms" - << ", max time: " << ti.get_max_time() << " ms."; + << ", min time: " << ti.LapTimes().Min() << " ms" + << ", max time: " << ti.LapTimes().Max() << " ms."; auto output = predictor->GetOutput(0); auto out = output->data(); @@ -127,6 +124,28 @@ void Run(const std::vector>& input_shapes, output_num *= output_shape[i]; } LOG(INFO) << "output_num: " << output_num; + + // please turn off memory_optimize_pass to use this feature. + if (FLAGS_arg_name != "") { + auto arg_tensor = predictor->GetTensor(FLAGS_arg_name); + auto arg_shape = arg_tensor->shape(); + int arg_num = 1; + std::ostringstream os; + os << "{"; + for (int i = 0; i < arg_shape.size(); ++i) { + arg_num *= arg_shape[i]; + os << arg_shape[i] << ","; + } + os << "}"; + float sum = 0.; + std::ofstream out(FLAGS_arg_name + ".txt"); + for (size_t i = 0; i < arg_num; ++i) { + sum += arg_tensor->data()[i]; + out << std::to_string(arg_tensor->data()[i]) << "\n"; + } + LOG(INFO) << FLAGS_arg_name << " shape is " << os.str() + << ", mean value is " << sum * 1. 
/ arg_num;
+  }
 }
 #endif
diff --git a/lite/api/opt.cc b/lite/api/opt.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a00646f4e11b68f0233a8b6009fbf847e9d50d63
--- /dev/null
+++ b/lite/api/opt.cc
@@ -0,0 +1,460 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#ifdef PADDLE_WITH_TESTING
+#include
+#endif
+// "supported_kernel_op_info.h", "all_kernel_faked.cc" and "kernel_src_map.h"
+// are created automatically during opt's compiling period
+#include
+#include "all_kernel_faked.cc"  // NOLINT
+#include "kernel_src_map.h"     // NOLINT
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/version.h"
+#include "lite/model_parser/compatible_pb.h"
+#include "lite/model_parser/pb/program_desc.h"
+#include "lite/utils/cp_logging.h"
+#include "lite/utils/string.h"
+#include "supported_kernel_op_info.h"  // NOLINT
+
+DEFINE_string(model_dir,
+              "",
+              "path of the model. This option will be ignored if model_file "
+              "and param_file exist");
+DEFINE_string(model_filename,
+              "",
+              "model topo filename of the model in the model set. This option"
+              " will be used to specify tailoring");
+DEFINE_string(param_filename,
+              "",
+              "model param filename of the model in the model set. This option"
+              " will be used to specify tailoring");
+DEFINE_string(model_set_dir,
+              "",
+              "path of the model set. This option will be used to specify"
+              " tailoring");
+DEFINE_string(model_file, "", "model file path of the combined-param model");
+DEFINE_string(param_file, "", "param file path of the combined-param model");
+DEFINE_string(
+    optimize_out_type,
+    "protobuf",
+    "store type of the output optimized model. protobuf/naive_buffer");
+DEFINE_bool(display_kernels, false, "Display kernel information");
+DEFINE_bool(record_tailoring_info,
+            false,
+            "Record kernel and operator information of the optimized model "
+            "for tailored compiling; the information is stored in the "
+            "optimized model path as hidden files");
+DEFINE_string(optimize_out, "", "path of the output optimized model");
+DEFINE_string(valid_targets,
+              "arm",
+              "The targets this model is optimized for; should be one of "
+              "(arm, opencl, x86, npu, xpu), separated by comma");
+DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels");
+DEFINE_bool(print_supported_ops,
+            false,
+            "Print supported operators on the input target");
+DEFINE_bool(print_all_ops,
+            false,
+            "Print all the valid operators of Paddle-Lite");
+DEFINE_bool(print_model_ops, false, "Print operators in the input model");
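Editor's note: these flags drive the conversion flow implemented below by `ParserValidPlaces` and `RunOptimize`. For embedding the same conversion in a host program instead of invoking the `opt` binary, a minimal sketch follows; the function name and model path are hypothetical, and an ARM-only, float-precision target is assumed:

```cpp
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_ops.h"     // NOLINT
#include "lite/api/paddle_use_passes.h"  // NOLINT

void ConvertToNaiveBuffer() {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("./mobilenet_v1");  // placeholder: non-combined model
  config.set_valid_places({
      paddle::lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
  });
  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);
  // Writes ./mobilenet_v1_opt.nb, the combined naive-buffer model;
  // the ".nb" suffix is appended by SaveOptimizedModel itself.
  predictor->SaveOptimizedModel("./mobilenet_v1_opt",
                                paddle::lite_api::LiteModelType::kNaiveBuffer);
}
```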
+
+namespace paddle {
+namespace lite_api {
+//! Display the kernel information.
+void DisplayKernels() {
+  LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString();
+}
+
+std::vector<Place> ParserValidPlaces() {
+  std::vector<Place> valid_places;
+  auto target_reprs = lite::Split(FLAGS_valid_targets, ",");
+  for (auto& target_repr : target_reprs) {
+    if (target_repr == "arm") {
+      valid_places.emplace_back(TARGET(kARM));
+    } else if (target_repr == "opencl") {
+      valid_places.emplace_back(
+          Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)});
+      valid_places.emplace_back(
+          Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)});
+      valid_places.emplace_back(
+          Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)});
+      valid_places.emplace_back(
+          Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)});
+      valid_places.emplace_back(
+          TARGET(kARM));  // enable kARM CPU kernel when no opencl kernel
+    } else if (target_repr == "x86") {
+      valid_places.emplace_back(TARGET(kX86));
+    } else if (target_repr == "npu") {
+      valid_places.emplace_back(TARGET(kNPU));
+    } else if (target_repr == "xpu") {
+      valid_places.emplace_back(TARGET(kXPU));
+    } else {
+      LOG(FATAL) << lite::string_format(
+          "Wrong target '%s' found, please check the command flag "
+          "'valid_targets'",
+          target_repr.c_str());
+    }
+  }
+
+  CHECK(!valid_places.empty())
+      << "At least one target should be set; please set the "
+         "command argument 'valid_targets'";
+
+  if (FLAGS_prefer_int8_kernel) {
+    LOG(WARNING) << "Int8 mode is only supported by the ARM target";
+    valid_places.insert(valid_places.begin(),
+                        Place{TARGET(kARM), PRECISION(kInt8)});
+  }
+  return valid_places;
+}
+
+void RunOptimize(const std::string& model_dir,
+                 const std::string& model_file,
+                 const std::string& param_file,
+                 const std::string& optimize_out,
+                 const std::string& optimize_out_type,
+                 const std::vector<Place>& valid_places,
+                 bool record_tailoring_info) {
+  if (!model_file.empty() && !param_file.empty()) {
+    LOG(WARNING)
+        << "Load combined-param model.
Option model_dir will be ignored"; + } + + lite_api::CxxConfig config; + config.set_model_dir(model_dir); + config.set_model_file(model_file); + config.set_param_file(param_file); + config.set_valid_places(valid_places); + auto predictor = lite_api::CreatePaddlePredictor(config); + + LiteModelType model_type; + if (optimize_out_type == "protobuf") { + model_type = LiteModelType::kProtobuf; + } else if (optimize_out_type == "naive_buffer") { + model_type = LiteModelType::kNaiveBuffer; + } else { + LOG(FATAL) << "Unsupported Model type :" << optimize_out_type; + } + + OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map); + predictor->SaveOptimizedModel( + optimize_out, model_type, record_tailoring_info); + if (record_tailoring_info) { + LOG(INFO) << "Record the information of tailored model into :" + << optimize_out; + } +} + +void CollectModelMetaInfo(const std::string& output_dir, + const std::vector& models, + const std::string& filename) { + std::set total; + for (const auto& name : models) { + std::string model_path = + lite::Join({output_dir, name, filename}, "/"); + auto lines = lite::ReadLines(model_path); + total.insert(lines.begin(), lines.end()); + } + std::string output_path = + lite::Join({output_dir, filename}, "/"); + lite::WriteLines(std::vector(total.begin(), total.end()), + output_path); +} +void PrintOpsInfo(std::set valid_ops = {}) { + std::vector targets = {"kHost", + "kX86", + "kCUDA", + "kARM", + "kOpenCL", + "kFPGA", + "kNPU", + "kXPU", + "kAny", + "kUnk"}; + int maximum_optype_length = 0; + for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) { + maximum_optype_length = it->first.size() > maximum_optype_length + ? it->first.size() + : maximum_optype_length; + } + std::cout << std::setiosflags(std::ios::internal); + std::cout << std::setw(maximum_optype_length) << "OP_name"; + for (int i = 0; i < targets.size(); i++) { + std::cout << std::setw(10) << targets[i].substr(1); + } + std::cout << std::endl; + if (valid_ops.empty()) { + for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) { + std::cout << std::setw(maximum_optype_length) << it->first; + auto ops_valid_places = it->second; + for (int i = 0; i < targets.size(); i++) { + if (std::find(ops_valid_places.begin(), + ops_valid_places.end(), + targets[i]) != ops_valid_places.end()) { + std::cout << std::setw(10) << "Y"; + } else { + std::cout << std::setw(10) << " "; + } + } + std::cout << std::endl; + } + } else { + for (auto op = valid_ops.begin(); op != valid_ops.end(); op++) { + std::cout << std::setw(maximum_optype_length) << *op; + // Check: If this kernel doesn't match any operator, we will skip it. + if (supported_ops.find(*op) == supported_ops.end()) { + continue; + } + // Print OP info. + auto ops_valid_places = supported_ops.at(*op); + for (int i = 0; i < targets.size(); i++) { + if (std::find(ops_valid_places.begin(), + ops_valid_places.end(), + targets[i]) != ops_valid_places.end()) { + std::cout << std::setw(10) << "Y"; + } else { + std::cout << std::setw(10) << " "; + } + } + std::cout << std::endl; + } + } +} +/// Print help information +void PrintHelpInfo() { + // at least one argument should be inputed + const std::string opt_version = lite::version(); + const char help_info[] = + "At least one argument should be inputed. 
Valid arguments are listed " + "below:\n" + " Arguments of model optimization:\n" + " `--model_dir=`\n" + " `--model_file=`\n" + " `--param_file=`\n" + " `--optimize_out_type=(protobuf|naive_buffer)`\n" + " `--optimize_out=`\n" + " `--valid_targets=(arm|opencl|x86|npu|xpu)`\n" + " `--prefer_int8_kernel=(true|false)`\n" + " `--record_tailoring_info=(true|false)`\n" + " Arguments of model checking and ops information:\n" + " `--print_all_ops=true` Display all the valid operators of " + "Paddle-Lite\n" + " `--print_supported_ops=true " + "--valid_targets=(arm|opencl|x86|npu|xpu)`" + " Display valid operators of input targets\n" + " `--print_model_ops=true --model_dir= " + "--valid_targets=(arm|opencl|x86|npu|xpu)`" + " Display operators in the input model\n"; + std::cout << "opt version:" << opt_version << std::endl + << help_info << std::endl; + exit(1); +} + +// Parse Input command +void ParseInputCommand() { + if (FLAGS_print_all_ops) { + std::cout << "All OPs supported by Paddle-Lite: " << supported_ops.size() + << " ops in total." << std::endl; + PrintOpsInfo(); + exit(1); + } else if (FLAGS_print_supported_ops) { + auto valid_places = paddle::lite_api::ParserValidPlaces(); + // get valid_targets string + std::vector target_types = {}; + for (int i = 0; i < valid_places.size(); i++) { + target_types.push_back(valid_places[i].target); + } + std::string targets_str = TargetToStr(target_types[0]); + for (int i = 1; i < target_types.size(); i++) { + targets_str = targets_str + TargetToStr(target_types[i]); + } + + std::cout << "Supported OPs on '" << targets_str << "': " << std::endl; + target_types.push_back(TARGET(kHost)); + target_types.push_back(TARGET(kUnk)); + + std::set valid_ops; + for (int i = 0; i < target_types.size(); i++) { + auto ops = supported_ops_target[static_cast(target_types[i])]; + valid_ops.insert(ops.begin(), ops.end()); + } + PrintOpsInfo(valid_ops); + exit(1); + } +} +// test whether this model is supported +void CheckIfModelSupported() { + // 1. 
parse valid places and valid targets + auto valid_places = paddle::lite_api::ParserValidPlaces(); + // set valid_ops + auto valid_ops = supported_ops_target[static_cast(TARGET(kHost))]; + auto valid_unktype_ops = supported_ops_target[static_cast(TARGET(kUnk))]; + valid_ops.insert( + valid_ops.end(), valid_unktype_ops.begin(), valid_unktype_ops.end()); + for (int i = 0; i < valid_places.size(); i++) { + auto target = valid_places[i].target; + auto ops = supported_ops_target[static_cast(target)]; + valid_ops.insert(valid_ops.end(), ops.begin(), ops.end()); + } + // get valid ops + std::set valid_ops_set(valid_ops.begin(), valid_ops.end()); + + // 2.Load model into program to get ops in model + std::string prog_path = FLAGS_model_dir + "/__model__"; + if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) { + prog_path = FLAGS_model_file; + } + lite::cpp::ProgramDesc cpp_prog; + framework::proto::ProgramDesc pb_proto_prog = + *lite::LoadProgram(prog_path, false); + lite::pb::ProgramDesc pb_prog(&pb_proto_prog); + // Transform to cpp::ProgramDesc + lite::TransformProgramDescAnyToCpp(pb_prog, &cpp_prog); + + std::set unsupported_ops; + std::set input_model_ops; + for (int index = 0; index < cpp_prog.BlocksSize(); index++) { + auto current_block = cpp_prog.GetBlock(index); + for (size_t i = 0; i < current_block->OpsSize(); ++i) { + auto& op_desc = *current_block->GetOp(i); + auto op_type = op_desc.Type(); + input_model_ops.insert(op_type); + if (valid_ops_set.count(op_type) == 0) { + unsupported_ops.insert(op_type); + } + } + } + // 3. Print ops_info of input model and check if this model is supported + if (FLAGS_print_model_ops) { + std::cout << "OPs in the input model include:\n"; + PrintOpsInfo(input_model_ops); + } + if (!unsupported_ops.empty()) { + std::string unsupported_ops_str = *unsupported_ops.begin(); + for (auto op_str = ++unsupported_ops.begin(); + op_str != unsupported_ops.end(); + op_str++) { + unsupported_ops_str = unsupported_ops_str + ", " + *op_str; + } + std::vector targets = {}; + for (int i = 0; i < valid_places.size(); i++) { + targets.push_back(valid_places[i].target); + } + std::sort(targets.begin(), targets.end()); + targets.erase(unique(targets.begin(), targets.end()), targets.end()); + std::string targets_str = TargetToStr(targets[0]); + for (int i = 1; i < targets.size(); i++) { + targets_str = targets_str + "," + TargetToStr(targets[i]); + } + + LOG(ERROR) << "Error: This model is not supported, because " + << unsupported_ops.size() << " ops are not supported on '" + << targets_str << "'. These unsupported ops are: '" + << unsupported_ops_str << "'."; + exit(1); + } + if (FLAGS_print_model_ops) { + std::cout << "Paddle-Lite supports this model!" 
              << std::endl;
+    exit(1);
+  }
+}
+
+void Main() {
+  if (FLAGS_display_kernels) {
+    DisplayKernels();
+    exit(0);
+  }
+
+  auto valid_places = ParserValidPlaces();
+  if (FLAGS_model_set_dir == "") {
+    RunOptimize(FLAGS_model_dir,
+                FLAGS_model_file,
+                FLAGS_param_file,
+                FLAGS_optimize_out,
+                FLAGS_optimize_out_type,
+                valid_places,
+                FLAGS_record_tailoring_info);
+    return;
+  }
+
+  if (!FLAGS_record_tailoring_info) {
+    LOG(WARNING) << "--model_set_dir option can only be used together with "
+                    "--record_tailoring_info=true";
+    return;
+  }
+
+  auto model_dirs = lite::ListDir(FLAGS_model_set_dir, true);
+  if (model_dirs.size() == 0) {
+    LOG(FATAL) << "[" << FLAGS_model_set_dir << "] does not contain any model";
+  }
+  // Optimize models in FLAGS_model_set_dir
+  for (const auto& name : model_dirs) {
+    std::string input_model_dir =
+        lite::Join<std::string>({FLAGS_model_set_dir, name}, "/");
+    std::string output_model_dir =
+        lite::Join<std::string>({FLAGS_optimize_out, name}, "/");
+
+    std::string model_file = "";
+    std::string param_file = "";
+
+    if (FLAGS_model_filename != "" && FLAGS_param_filename != "") {
+      model_file =
+          lite::Join<std::string>({input_model_dir, FLAGS_model_filename}, "/");
+      param_file =
+          lite::Join<std::string>({input_model_dir, FLAGS_param_filename}, "/");
+    }
+
+    LOG(INFO) << "Start optimizing model: " << input_model_dir;
+    RunOptimize(input_model_dir,
+                model_file,
+                param_file,
+                output_model_dir,
+                FLAGS_optimize_out_type,
+                valid_places,
+                FLAGS_record_tailoring_info);
+    LOG(INFO) << "Optimize done.";
+  }
+
+  // Collect all models' information
+  CollectModelMetaInfo(
+      FLAGS_optimize_out, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME);
+  CollectModelMetaInfo(
+      FLAGS_optimize_out, model_dirs, lite::TAILORD_OPS_LIST_NAME);
+  CollectModelMetaInfo(FLAGS_optimize_out,
+                       model_dirs,
+                       lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
+  CollectModelMetaInfo(
+      FLAGS_optimize_out, model_dirs, lite::TAILORD_KERNELS_LIST_NAME);
+}
+
+}  // namespace lite_api
+}  // namespace paddle
+
+int main(int argc, char** argv) {
+  // If there is no input argument, print help info.
+  if (argc < 2) {
+    paddle::lite_api::PrintHelpInfo();
+  }
+  google::ParseCommandLineFlags(&argc, &argv, false);
+  paddle::lite_api::ParseInputCommand();
+  paddle::lite_api::CheckIfModelSupported();
+  paddle::lite_api::Main();
+  return 0;
+}
diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc
index f148096bb69a3a249521bcb847d5beae3f8297f9..9f071cf7780e27defdd1fcd6be02844618165fb6 100644
--- a/lite/api/paddle_api.cc
+++ b/lite/api/paddle_api.cc
@@ -93,7 +93,7 @@ void Tensor::CopyFromCpu(const T *src_data) {
   }
 }
 template <typename T>
-void Tensor::CopyToCpu(T *data) {
+void Tensor::CopyToCpu(T *data) const {
   const T *src_data = tensor(raw_tensor_)->data<T>();
   int64_t num = tensor(raw_tensor_)->numel();
   CHECK(num > 0) << "You should call Resize interface first";
@@ -121,12 +121,13 @@ template void Tensor::CopyFromCpu(const int *);
 template void Tensor::CopyFromCpu(const float *);
 template void Tensor::CopyFromCpu(const int8_t *);
 template void Tensor::CopyFromCpu(const int *);
+template void Tensor::CopyFromCpu(const int64_t *);
 template void Tensor::CopyFromCpu(const float *);
 template void Tensor::CopyFromCpu(const int8_t *);
 
-template void Tensor::CopyToCpu(int8_t *);
-template void Tensor::CopyToCpu(float *);
-template void Tensor::CopyToCpu(int *);
+template void Tensor::CopyToCpu(int8_t *) const;
+template void Tensor::CopyToCpu(float *) const;
+template void Tensor::CopyToCpu(int *) const;
 
 shape_t Tensor::shape() const {
   return ctensor(raw_tensor_)->dims().Vectorize();
@@ -189,5 +190,27 @@ void ConfigBase::set_threads(int threads) {
 #endif
 }
 
+// set model data in combined format; `set_model_from_file` loads the model
+// from a file, while `set_model_from_buffer` loads it from a memory buffer
+void MobileConfig::set_model_from_file(const std::string &x) {
+  lite_model_file_ = x;
+}
+void MobileConfig::set_model_from_buffer(const std::string &x) {
+  lite_model_file_ = x;
+  model_from_memory_ = true;
+}
+void MobileConfig::set_model_buffer(const char *model_buffer,
+                                    size_t model_buffer_size,
+                                    const char *param_buffer,
+                                    size_t param_buffer_size) {
+  LOG(WARNING) << "warning: `set_model_buffer` will be abandoned in "
+                  "release/v3.0.0; the new method `set_model_from_buffer(const "
+                  "std::string &x)` is recommended.";
+  model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size);
+  param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size);
+  model_from_memory_ = true;
+}
+
 }  // namespace lite_api
 }  // namespace paddle
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index 42b455da811fe1a21277d38f2e1237000276b1ff..307eeb74e8b4cdc3b2d6188eb18490e4dcf89b8f 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -49,7 +49,7 @@ struct LITE_API Tensor {
   void CopyFromCpu(const T* data);
 
   template <typename T>
-  void CopyToCpu(T* data);
+  void CopyToCpu(T* data) const;
 
   /// Shape of the tensor.
   /// Shape of the tensor.
   shape_t shape() const;
 
   TargetType target() const;
@@ -133,6 +133,9 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file_;
   std::string param_file_;
   bool model_from_memory_{false};
+#ifdef LITE_WITH_X86
+  int x86_math_library_math_threads_ = 1;
+#endif
 
  public:
   void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
@@ -151,27 +154,54 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file() const { return model_file_; }
   std::string param_file() const { return param_file_; }
   bool model_from_memory() const { return model_from_memory_; }
+
+#ifdef LITE_WITH_X86
+  void set_x86_math_library_num_threads(int threads) {
+    x86_math_library_math_threads_ = threads;
+  }
+  int x86_math_library_num_threads() const {
+    return x86_math_library_math_threads_;
+  }
+#endif
 };
 
 /// MobileConfig is the config for the light weight predictor, it will skip
 /// IR optimization or other unnecessary stages.
 class LITE_API MobileConfig : public ConfigBase {
+  // Whether to load model data from memory. Model data will be loaded from a
+  // memory buffer if model_from_memory_ is true.
+  bool model_from_memory_{false};
+
+  // Model data read from a file or a memory buffer, in combined format.
+  std::string lite_model_file_;
+
+  // NOTE: This is a deprecated variable and will be removed in a later release.
   std::string model_buffer_;
   std::string param_buffer_;
-  bool model_from_memory_{false};
 
  public:
+  // Set model data in combined format: `set_model_from_file` loads the model
+  // from a file, while `set_model_from_buffer` loads it from a memory buffer.
+  void set_model_from_file(const std::string& x);
+  void set_model_from_buffer(const std::string& x);
+  // Return the model data held in lite_model_file_, which is in combined
+  // format.
+  const std::string& lite_model_file() const { return lite_model_file_; }
+
+  // Return model_from_memory_, which indicates whether the model is loaded
+  // from a memory buffer.
+  bool model_from_memory() const { return model_from_memory_; }
+
+  // NOTE: This is a deprecated API and will be removed in a later release.
   void set_model_buffer(const char* model_buffer,
                         size_t model_buffer_size,
                         const char* param_buffer,
-                        size_t param_buffer_size) {
-    model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size);
-    param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size);
-    model_from_memory_ = true;
-  }
+                        size_t param_buffer_size);
 
-  bool model_from_memory() const { return model_from_memory_; }
+  // NOTE: This is a deprecated API and will be removed in a later release.
   const std::string& model_buffer() const { return model_buffer_; }
+
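The new x86 accessor pair above only exists when LITE_WITH_X86 is defined, so callers should guard too; a sketch mirroring the test_step_rnn_lite_x86 change later in this patch (the model path is a placeholder):

    paddle::lite_api::CxxConfig config;
    config.set_model_dir("step_rnn_model_dir");  // placeholder path
    #ifdef LITE_WITH_X86
    // Pin the x86 math library (e.g. MKL) to a single thread.
    config.set_x86_math_library_num_threads(1);
    #endif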
+  // NOTE: This is a deprecated API and will be removed in a later release.
   const std::string& param_buffer() const { return param_buffer_; }
 };
diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc
index 69d544c3decac9f312bc9eb03cdc6c3702c5032b..9213a24e5c0614550a098c4de8d97b6cf6695177 100644
--- a/lite/api/paddle_api_test.cc
+++ b/lite/api/paddle_api_test.cc
@@ -72,7 +72,7 @@ TEST(CxxApi, run) {
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
 TEST(LightApi, run) {
   lite_api::MobileConfig config;
-  config.set_model_dir(FLAGS_model_dir + ".opt2.naive");
+  config.set_model_from_file(FLAGS_model_dir + ".opt2.naive.nb");
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
@@ -109,16 +109,11 @@ TEST(LightApi, run) {
 // Demo2 for Loading model from memory
 TEST(MobileConfig, LoadfromMemory) {
   // Get naive buffer
-  auto model_path = std::string(FLAGS_model_dir) + ".opt2.naive/__model__.nb";
-  auto params_path = std::string(FLAGS_model_dir) + ".opt2.naive/param.nb";
-  std::string model_buffer = lite::ReadFile(model_path);
-  size_t size_model = model_buffer.length();
-  std::string params_buffer = lite::ReadFile(params_path);
-  size_t size_params = params_buffer.length();
+  auto model_file = std::string(FLAGS_model_dir) + ".opt2.naive.nb";
+  std::string model_buffer = lite::ReadFile(model_file);
 
   // set model buffer and run model
   lite_api::MobileConfig config;
-  config.set_model_buffer(
-      model_buffer.c_str(), size_model, params_buffer.c_str(), size_params);
+  config.set_model_from_buffer(model_buffer);
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
   auto input_tensor = predictor->GetInput(0);
diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc
index 894d839185ea9e1b6b47b87c398f249f044c2b51..2cced919e601f8ecb79ce262a2b083d5b6862da9 100644
--- a/lite/api/paddle_place.cc
+++ b/lite/api/paddle_place.cc
@@ -55,7 +55,8 @@ const std::string& TargetToStr(TargetType target) {
                                           "any",
                                           "fpga",
                                           "npu",
-                                          "xpu"};
+                                          "xpu",
+                                          "bm"};
   auto x = static_cast<int>(target);
   CHECK_LT(x, static_cast<int>(TARGET(NUM)));
   return target2string[x];
@@ -77,7 +78,8 @@ const std::string& PrecisionToStr(PrecisionType precision) {
 }
 
 const std::string& DataLayoutToStr(DataLayoutType layout) {
-  static const std::string datalayout2string[] = {"unk", "NCHW", "any", "NHWC"};
+  static const std::string datalayout2string[] = {
+      "unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
   auto x = static_cast<int>(layout);
   CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
   return datalayout2string[x];
@@ -93,7 +95,8 @@ const std::string& TargetRepr(TargetType target) {
                                           "kAny",
                                           "kFPGA",
                                           "kNPU",
-                                          "kXPU"};
+                                          "kXPU",
+                                          "kBM"};
   auto x = static_cast<int>(target);
   CHECK_LT(x, static_cast<int>(TARGET(NUM)));
   return target2string[x];
@@ -115,8 +118,13 @@ const std::string& PrecisionRepr(PrecisionType precision) {
 }
 
 const std::string& DataLayoutRepr(DataLayoutType layout) {
-  static const std::string datalayout2string[] = {
-      "kUnk", "kNCHW", "kAny", "kNHWC"};
+  static const std::string datalayout2string[] = {"kUnk",
+                                                  "kNCHW",
+                                                  "kAny",
+                                                  "kNHWC",
+                                                  "kImageDefault",
+                                                  "kImageFolder",
+                                                  "kImageNW"};
   auto x = static_cast<int>(layout);
   CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
   return datalayout2string[x];
@@ -129,6 +137,7 @@ std::set<TargetType> ExpandValidTargets(TargetType target) {
                                                 TARGET(kOpenCL),
                                                 TARGET(kNPU),
                                                 TARGET(kXPU),
+                                                TARGET(kBM),
                                                 TARGET(kFPGA)});
   if (target == TARGET(kAny)) {
     return valid_set;
@@ -146,8 +155,12 @@ std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) {
 }
 
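These string tables are indexed by the raw enum value, so each new entry has to sit at the index of the matching enumerator in paddle_place.h (kBM = 10 is the 11th entry). A quick sanity check, as a sketch:

    #include <cassert>

    void CheckBmPlaceStrings() {
      // TargetToStr/TargetRepr index their tables with static_cast<int>(target).
      assert(paddle::lite_api::TargetToStr(TARGET(kBM)) == "bm");
      assert(paddle::lite_api::TargetRepr(TARGET(kBM)) == "kBM");
    }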
 std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
-  static const std::set<DataLayoutType> valid_set(
-      {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
+  static const std::set<DataLayoutType> valid_set({DATALAYOUT(kNCHW),
+                                                   DATALAYOUT(kAny),
+                                                   DATALAYOUT(kNHWC),
+                                                   DATALAYOUT(kImageDefault),
+                                                   DATALAYOUT(kImageFolder),
+                                                   DATALAYOUT(kImageNW)});
   if (layout == DATALAYOUT(kAny)) {
     return valid_set;
   }
diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h
index 07284be095c05e5dfa069b0973d5982cf1f07c8a..7da52adc7fb6fdd70de3b098508e4622496bed7d 100644
--- a/lite/api/paddle_place.h
+++ b/lite/api/paddle_place.h
@@ -52,8 +52,9 @@ enum class TargetType : int {
   kFPGA = 7,
   kNPU = 8,
   kXPU = 9,
+  kBM = 10,
   kAny = 6,  // any target
-  NUM = 10,  // number of fields.
+  NUM = 11,  // number of fields.
 };
 enum class PrecisionType : int {
   kUnk = 0,
@@ -71,8 +72,11 @@ enum class DataLayoutType : int {
   kUnk = 0,
   kNCHW = 1,
   kNHWC = 3,
-  kAny = 2,  // any data layout
-  NUM = 4,   // number of fields.
+  kImageDefault = 4,  // for opencl image2d
+  kImageFolder = 5,   // for opencl image2d
+  kImageNW = 6,       // for opencl image2d
+  kAny = 2,           // any data layout
+  NUM = 7,            // number of fields.
 };
 
 typedef enum {
@@ -112,6 +116,34 @@ static size_t PrecisionTypeLength(PrecisionType type) {
   }
 }
 
+template <typename T>
+struct PrecisionTypeTrait {
+  constexpr static PrecisionType Type() { return PrecisionType::kUnk; }
+};
+
+#define _ForEachPrecisionTypeHelper(callback, cpp_type, precision_type) \
+  callback(cpp_type, ::paddle::lite_api::PrecisionType::precision_type);
+
+#define _ForEachPrecisionType(callback)                   \
+  _ForEachPrecisionTypeHelper(callback, bool, kBool);     \
+  _ForEachPrecisionTypeHelper(callback, float, kFloat);   \
+  _ForEachPrecisionTypeHelper(callback, int8_t, kInt8);   \
+  _ForEachPrecisionTypeHelper(callback, int16_t, kInt16); \
+  _ForEachPrecisionTypeHelper(callback, int, kInt32);     \
+  _ForEachPrecisionTypeHelper(callback, int64_t, kInt64);
+
+#define DefinePrecisionTypeTrait(cpp_type, precision_type)           \
+  template <>                                                        \
+  struct PrecisionTypeTrait<cpp_type> {                              \
+    constexpr static PrecisionType Type() { return precision_type; } \
+  }
+
+_ForEachPrecisionType(DefinePrecisionTypeTrait);
+
+#undef _ForEachPrecisionTypeHelper
+#undef _ForEachPrecisionType
+#undef DefinePrecisionTypeTrait
+
 #define TARGET(item__) paddle::lite_api::TargetType::item__
 #define PRECISION(item__) paddle::lite_api::PrecisionType::item__
 #define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__
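The PrecisionTypeTrait machinery added to paddle_place.h maps a C++ scalar type to its PrecisionType at compile time, generating one specialization per entry of _ForEachPrecisionType; unlisted types fall back to the primary template and report kUnk. A sketch of what the expansion guarantees:

    #include <cstdint>

    using paddle::lite_api::PrecisionType;
    using paddle::lite_api::PrecisionTypeTrait;

    static_assert(PrecisionTypeTrait<float>::Type() == PrecisionType::kFloat,
                  "float maps to kFloat");
    static_assert(PrecisionTypeTrait<int64_t>::Type() == PrecisionType::kInt64,
                  "int64_t maps to kInt64");
    // double is not registered, so it falls back to the primary template.
    static_assert(PrecisionTypeTrait<double>::Type() == PrecisionType::kUnk,
                  "unregistered types yield kUnk");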
diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 70355fdf890eb63cd5bedd5bab42a2dd69af0927..943760d30742b74a0fe9150e4c2d8c8bb5dbc52a 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -20,7 +20,6 @@ USE_MIR_PASS(static_kernel_pick_pass);
 USE_MIR_PASS(variable_place_inference_pass);
 USE_MIR_PASS(type_target_cast_pass);
 USE_MIR_PASS(generate_program_pass);
-USE_MIR_PASS(subgraph_program_pass);
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
 
@@ -32,11 +31,17 @@ USE_MIR_PASS(lite_fc_fuse_pass);
 USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
 USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
 USE_MIR_PASS(lite_interpolate_fuse_pass);
+USE_MIR_PASS(lite_sequence_pool_concat_fuse_pass);
 USE_MIR_PASS(identity_scale_eliminate_pass);
 USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
 USE_MIR_PASS(lite_conv_activation_fuse_pass);
+USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
 USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
 USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
 USE_MIR_PASS(memory_optimize_pass);
+USE_MIR_PASS(elementwise_mul_constant_eliminate_pass);
+USE_MIR_PASS(npu_subgraph_pass);
+USE_MIR_PASS(xpu_subgraph_pass);
+USE_MIR_PASS(weight_quantization_preprocess_pass);
diff --git a/lite/api/python/pybind/CMakeLists.txt b/lite/api/python/pybind/CMakeLists.txt
index 178f167e6a1627d01df13b2e105e0af36b20601a..eabb6b150b93a722282118c3932676cd1aee5da8 100644
--- a/lite/api/python/pybind/CMakeLists.txt
+++ b/lite/api/python/pybind/CMakeLists.txt
@@ -4,3 +4,6 @@ if (NOT LITE_ON_TINY_PUBLISH)
 endif()
 
 lite_cc_library(lite_pybind SHARED SRCS pybind.cc DEPS ${PYBIND_DEPS})
+if (LITE_ON_TINY_PUBLISH)
+  set_target_properties(lite_pybind PROPERTIES COMPILE_FLAGS "-flto -fdata-sections")
+endif()
diff --git a/lite/api/python/pybind/pybind.cc b/lite/api/python/pybind/pybind.cc
index 2df2e8f8f8aa56bb71b0e1cb293df2ecbbafd0bb..2dfe0c49490ecd13e8a3ce480807bdf3875348b7 100644
--- a/lite/api/python/pybind/pybind.cc
+++ b/lite/api/python/pybind/pybind.cc
@@ -116,6 +116,8 @@ void BindLiteMobileConfig(py::module *m) {
   py::class_<MobileConfig> mobile_config(*m, "MobileConfig");
 
   mobile_config.def(py::init<>())
+      .def("set_model_from_file", &MobileConfig::set_model_from_file)
+      .def("set_model_from_buffer", &MobileConfig::set_model_from_buffer)
       .def("set_model_dir", &MobileConfig::set_model_dir)
       .def("model_dir", &MobileConfig::model_dir)
       .def("set_model_buffer", &MobileConfig::set_model_buffer)
@@ -165,6 +167,9 @@ void BindLitePlace(py::module *m) {
   py::enum_<DataLayoutType>(*m, "DataLayoutType")
       .value("NCHW", DataLayoutType::kNCHW)
       .value("NHWC", DataLayoutType::kNHWC)
+      .value("ImageDefault", DataLayoutType::kImageDefault)
+      .value("ImageFolder", DataLayoutType::kImageFolder)
+      .value("ImageNW", DataLayoutType::kImageNW)
       .value("Any", DataLayoutType::kAny);
 
   // Place
diff --git a/lite/api/resnet50_test_fpga.cc b/lite/api/resnet50_test_fpga.cc
index ab647f96998f1c0e73476369611218d0a7930c57..75e6f0cbbc43c3cd7eb9bfa89bc004554ea6f85b 100644
--- a/lite/api/resnet50_test_fpga.cc
+++ b/lite/api/resnet50_test_fpga.cc
@@ -31,11 +31,7 @@ TEST(ResNet50, test) {
   std::vector<Place> valid_places(
       {Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}});
 
-  predictor.Build(FLAGS_model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
-                  valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
diff --git a/lite/api/test_resnet50_lite_bm.cc b/lite/api/test_resnet50_lite_bm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..62a58704f4245b8618540ea7109447dd99d0bfea
--- /dev/null
+++ b/lite/api/test_resnet50_lite_bm.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
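The new BM test that follows drives the full CxxApi Predictor with an explicit pass list; the essential pattern is (a sketch with a placeholder model directory):

    paddle::lite::Predictor predictor;
    std::vector<paddle::lite::Place> valid_places{
        {TARGET(kBM), PRECISION(kFloat)}, {TARGET(kX86), PRECISION(kFloat)}};
    // The BM backend is wired in through a dedicated subgraph pass.
    std::vector<std::string> passes{"bm_subgraph_pass"};
    predictor.Build("resnet50_model_dir", "", "", valid_places, passes);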
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <fstream>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+DEFINE_string(input_img_txt_path,
+              "",
+              "if set, read the input data from this text file; otherwise the "
+              "input is filled with ones.");
+
+namespace paddle {
+namespace lite {
+
+void TestModel(const std::vector<Place>& valid_places) {
+  lite::Predictor predictor;
+  std::vector<std::string> passes;
+  passes.push_back("bm_subgraph_pass");
+  predictor.Build(FLAGS_model_dir, "", "", valid_places, passes);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  if (FLAGS_input_img_txt_path.empty()) {
+    for (int i = 0; i < item_size; i++) {
+      data[i] = 1;
+    }
+  } else {
+    std::fstream fs(FLAGS_input_img_txt_path, std::ios::in);
+    if (!fs.is_open()) {
+      LOG(FATAL) << "failed to open input_img_txt_path file.";
+    }
+    for (int i = 0; i < item_size; i++) {
+      fs >> data[i];
+    }
+  }
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms on average.";
+
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  auto* out_data = out->data<float>();
+  FILE* fp = fopen("result.txt", "wb");
+  for (int i = 0; i < out->numel(); i++) {
+    fprintf(fp, "%f\n", out_data[i]);
+  }
+  fclose(fp);
+}
+
+TEST(ResNet50, test_bm) {
+  std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+
+  TestModel(valid_places);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc
index c483373dc745f6520d51ece3936448ada71990d3..013fd82b19bc22ace22184389249a7b2d9bf237e 100644
--- a/lite/api/test_step_rnn_lite_x86.cc
+++ b/lite/api/test_step_rnn_lite_x86.cc
@@ -12,20 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
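A note on the speed report in the test above: GetCurrentUS() returns wall-clock microseconds, so (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 is the mean per-run latency in milliseconds. For example, if 100 repeats take 2,500,000 us in total, the test reports 25 ms; the warmup iterations are deliberately excluded from the timed window.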
- #include #include #include @@ -44,6 +30,9 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { std::string model_dir = FLAGS_model_dir; lite_api::CxxConfig config; config.set_model_dir(model_dir); +#ifdef LITE_WITH_X86 + config.set_x86_math_library_num_threads(1); +#endif config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); @@ -62,7 +51,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { "micro_video_id", "vertical_type_id"}; - for (int i = 0; i < target_names.size(); ++i) { + for (size_t i = 0; i < target_names.size(); ++i) { auto input_tensor = predictor->GetInput(i); int size = 0; if (i == 6 || i == 8) { @@ -87,8 +76,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { predictor->Run(); } - // LOG(INFO) << "================== Speed Report ==================="; - LOG(INFO) << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + LOG(INFO) << "warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 << " ms in average."; @@ -99,8 +87,8 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { std::vector out_shape = out->shape(); - for (int i = 0; i < results.size(); ++i) { - for (int j = 0; j < results[i].size(); ++j) { + for (size_t i = 0; i < results.size(); ++i) { + for (size_t j = 0; j < results[i].size(); ++j) { EXPECT_NEAR( out->data()[j + (out_shape[1] * i)], results[i][j], 1e-6); } diff --git a/lite/backends/CMakeLists.txt b/lite/backends/CMakeLists.txt index dec63e6efa0e4c4548646ebdd6f6de24f046d6d0..e3517464812a24c9911e824c53841efc05dd2bc5 100644 --- a/lite/backends/CMakeLists.txt +++ b/lite/backends/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(fpga) add_subdirectory(host) add_subdirectory(npu) add_subdirectory(xpu) +add_subdirectory(bm) diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index cbbcf49a5fd55dabd6b072bc6b3b2e3f9bb91a13..6f6f7e7aa71ba5067d831a2bcc2b7b933205fbe0 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -57,9 +57,10 @@ endif() if (NOT HAS_ARM_MATH_LIB_DIR) # TODO(xxx): seperate them and do not deps proto, eigen3 - cc_library(math_arm SRCS - funcs.cc + cc_library(math_arm SRCS + funcs.cc packed_sgemm.cc + packed_sgemm_c4.cc sgemm.cc gemm_prepacked_int8.cc gemm_s8.cc @@ -67,25 +68,26 @@ if (NOT HAS_ARM_MATH_LIB_DIR) gemv_arm_int8.cc conv3x3s1_direct_fp32.cc conv3x3s2_direct_fp32.cc - conv3x3s1_depthwise_fp32.cc - conv3x3s2_depthwise_fp32.cc + conv3x3s1p01_depthwise_fp32.cc + conv3x3s2p01_depthwise_fp32.cc + conv3x3s1px_depthwise_fp32.cc + conv3x3s2px_depthwise_fp32.cc conv3x3s1_direct_int8.cc conv3x3s2_direct_int8.cc conv3x3s1_depthwise_int8.cc conv3x3s2_depthwise_int8.cc conv5x5s1_depthwise_int8.cc conv5x5s1_depthwise_fp32.cc + conv5x5s2_depthwise_int8.cc conv5x5s2_depthwise_fp32.cc - conv_depthwise_3x3p0.cc - conv_depthwise_3x3p1.cc - conv_depthwise_3x3s1.cc - conv_depthwise_3x3s2.cc + conv3x3_winograd_fp32_c4.cc conv_winograd_3x3.cc conv_impl.cc - softmax.cc + softmax.cc scale.cc pooling.cc elementwise.cc + layout.cc lrn.cc decode_bboxes.cc concat.cc @@ -119,6 +121,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) stack.cc affine_channel.cc anchor_generator.cc + split_merge_lod_tenosr.cc + reduce_prod.cc DEPS ${lite_kernel_deps} context tensor) endif() - diff --git a/lite/backends/arm/math/col_im_transform.cc b/lite/backends/arm/math/col_im_transform.cc index 
b5d2c6af13cc1dd864eaac6cb6589cc879f029fe..38be1d689dd47ab59baf417e40989a91bb6366e0 100644
--- a/lite/backends/arm/math/col_im_transform.cc
+++ b/lite/backends/arm/math/col_im_transform.cc
@@ -32,8 +32,10 @@ void col2im(const float* data_col,
             const int width,
             const int kernel_h,
             const int kernel_w,
-            const int pad_h,
-            const int pad_w,
+            const int pad_h0,
+            const int pad_h1,
+            const int pad_w0,
+            const int pad_w1,
             const int stride_h,
             const int stride_w,
             const int dilation_h,
@@ -41,19 +43,22 @@ void col2im(const float* data_col,
             float* data_im) {
   memset(data_im, 0, height * width * channels * sizeof(float));
   const int output_h =
-      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+      (height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
+          stride_h +
+      1;
   const int output_w =
-      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+      (width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
+      1;
   const int channel_size = height * width;
   for (int channel = channels; channel--; data_im += channel_size) {
     for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
       for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
-        int input_row = -pad_h + kernel_row * dilation_h;
+        int input_row = -pad_h0 + kernel_row * dilation_h;
         for (int output_rows = output_h; output_rows; output_rows--) {
           if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
             data_col += output_w;
           } else {
-            int input_col = -pad_w + kernel_col * dilation_w;
+            int input_col = -pad_w0 + kernel_col * dilation_w;
             for (int output_col = output_w; output_col; output_col--) {
               if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
                 data_im[input_row * width + input_col] += *data_col;
diff --git a/lite/backends/arm/math/col_im_transform.h b/lite/backends/arm/math/col_im_transform.h
index 8560679d7f4091c4cb424b54e54a42cf6e7e8905..e3e32c4715ade10972f77e0c4d5a2cd4d16b4725 100644
--- a/lite/backends/arm/math/col_im_transform.h
+++ b/lite/backends/arm/math/col_im_transform.h
@@ -26,8 +26,10 @@ void col2im(const Dtype* data_col,
             const int width,
             const int kernel_h,
             const int kernel_w,
-            const int pad_h,
-            const int pad_w,
+            const int pad_h0,
+            const int pad_h1,
+            const int pad_w0,
+            const int pad_w1,
             const int stride_h,
             const int stride_w,
             const int dilation_h,
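The col2im change above replaces the symmetric pads (pad_h, pad_w) with independent pads per side (pad_h0/pad_h1 for top/bottom, pad_w0/pad_w1 for left/right), so the reconstructed extent becomes output_h = (height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1, and only the leading pads (pad_h0, pad_w0) shift the input origin. For example, height = 224, kernel_h = 3, dilation_h = 1, stride_h = 1 with pad_h0 = 1, pad_h1 = 0 gives (224 + 1 + 0 - 3) / 1 + 1 = 223, a shape the old symmetric-pad formula could not express.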
diff --git a/lite/backends/arm/math/concat.cc b/lite/backends/arm/math/concat.cc
index 9b94cefa16bca0dd487ad0e4f6b88e604b694416..65f93453388d7f41d73669f583d189bec9035bb5 100644
--- a/lite/backends/arm/math/concat.cc
+++ b/lite/backends/arm/math/concat.cc
@@ -26,31 +26,32 @@ namespace math {
 void concat_func(const std::vector<lite::Tensor *> &input,
                  const int axis,
                  lite::Tensor *output) {
-  size_t num = input.size();
-  int rows = 1;
+  int64_t concat_input_size = 1;
+  int64_t num_concats = 1;
   auto dim_0 = input[0]->dims();
-  for (int i = 0; i < axis; ++i) {
-    rows *= dim_0[i];
+  size_t num = input.size();
+  for (int i = axis + 1; i < dim_0.size(); i++) {
+    concat_input_size *= dim_0[i];
   }
-  int out_rows = rows, out_cols = 0;
-
-  std::vector<int> input_cols(input.size());
-  for (int i = 0; i < num; ++i) {
-    int t_cols = input[i]->numel() / rows;
-    out_cols += t_cols;
-    input_cols[i] = t_cols;
+  for (int i = 0; i < axis; i++) {
+    num_concats *= dim_0[i];
   }
-
-  // computation
-  for (int k = 0; k < out_rows; ++k) {
-    float *dst_ptr = output->mutable_data<float>() + k * out_cols;
-    int col_idx = 0;
-    for (int j = 0; j < num; ++j) {
-      int col_len = input_cols[j];
-      const float *src_prt = input[j]->data<float>() + k * col_len;
-      std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len);
-      col_idx += col_len;
+  float *dst_ptr = output->mutable_data<float>();
+  const int out_concat_axis = output->dims()[axis];
+  int64_t offset_concat_axis = 0;
+  int64_t out_sum = out_concat_axis * concat_input_size;
+  for (int n = 0; n < num; n++) {
+    auto dims = input[n]->dims();
+    const float *src_ptr = input[n]->data<float>();
+    int64_t in_concat_axis = dims[axis];
+    float *dout_ptr = dst_ptr + offset_concat_axis * concat_input_size;
+    int64_t in_sum = in_concat_axis * concat_input_size;
+    for (int i = 0; i < num_concats; i++) {
+      std::memcpy(dout_ptr, src_ptr, sizeof(float) * in_sum);
+      dout_ptr += out_sum;
+      src_ptr += in_sum;
     }
+    offset_concat_axis += in_concat_axis;
   }
 }
diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d1992f62bbfa9e15ab4d39565f7fe3555e17b215
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc
@@ -0,0 +1,1310 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/backends/arm/math/packed_sgemm_c4.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+#include <arm_neon.h>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+void input_trans_c4_8x8(const float* src,
+                        int src_stride,
+                        float* dest,
+                        int dest_stride);
+void output_trans_c4_6x8(const float* src,
+                         int src_stride,
+                         float* dest,
+                         int dest_stride);
+void output_trans_c4_post_6x8(const float* src,
+                              int src_stride,
+                              float* dest,
+                              int dest_stride,
+                              float* bias_value,
+                              bool has_relu);
+void input_trans_c4_4x4(const float* src,
+                        int src_stride,
+                        int src_h_stride,
+                        float* dest,
+                        int dest_stride,
+                        int dest_h_stride);
+void output_trans_c4_post_2x4(const float* src,
+                              int src_stride,
+                              int src_h_stride,
+                              float* dest,
+                              int dest_stride,
+                              int dest_h_stride,
+                              float* bias_value,
+                              bool has_relu);
+void weight_trans_c4_8x8(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void weight_trans_c4_4x4(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+
+/*
+* The following functions conv_compute_6x6_3x3 and conv_compute_2x2_3x3[_small]
+* are based on
+* MNN[https://github.com/alibaba/MNN]
+*
+* Copyright © 2018, Alibaba Group Holding Limited
+*/
+
+// F(6,3)
+void conv_compute_6x6_3x3(const float* input,
+                          float* output,
+                          int num,
+                          int chout,
+                          int hout,
+                          int wout,
+                          int chin,
+                          int hin,
+                          int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx) {
+  auto act_param = param.activation_param;
+  const int pad_h = (*param.paddings)[0];
+  const int pad_w = (*param.paddings)[2];
+  float* tmp_work_space =
+      ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
+
+  int in_n_stride = chin * hin * win;
+  int out_n_stride = chout * hout * wout;
+  int ic_stride = win * hin;
+  int
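Winograd F(m,3) consumes (m+2)x(m+2) input tiles and produces m x m output tiles, which is where the constants in this file come from: F(6,3) works on 8x8 tiles with 64 transform points, F(2,3) on 4x4 tiles with 16. A sketch of the tile bookkeeping for F(6,3), assuming a 20x20 output plane:

    int wout = 20, hout = 20;
    int tile_w = (wout + 5) / 6;      // 4 tile columns, 6 outputs each
    int tile_h = (hout + 5) / 6;      // 4 tile rows
    int size_tile = tile_h * tile_w;  // 16 tiles, walked in blocks of 8
    // Each of the 64 transform points then becomes one small GEMM over all
    // tiles in the block (the gi loop in the function body).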
oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 5) / 6; + int tile_h = (hout + 5) / 6; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 256; + memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 256; + + // begin compute + for (int ni = 0; ni < num; ++ni) { + // trans input to c4 + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; +#pragma omp parallel for num_threads(threads) + for (int tbi = 0; tbi < block_count; ++tbi) { +#ifdef ARM_WITH_OMP + float* tmp_data = + g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride; + float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 256; + float* trans_remain_tmp_data = + g_trans_remain_tmp_data + omp_get_thread_num() * 256; +#else + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; +#endif + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index * 6; + int src_y = th_index * 6; + int ex = src_x + 8 > w_pad ? w_pad - src_x : 8; + int ey = src_y + 8 > h_pad ? 
h_pad - src_y : 8; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 8 && ey == 8) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + for (int i = 0; i < 8; ++i) { + const float* ci_ptr = src_ci + i * w_pad * 4; + input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32); + } + float* dst_ci = dst_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + input_trans_c4_8x8(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); + } + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 256 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 32; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + // trans + for (int i = 0; i < 8; ++i) { + float* ci_ptr = trans_remain_tmp_data + i * 32; + input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32); + } + float* dst_ci = dst_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + input_trans_c4_8x8(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); + } + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 256; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; + for (int gi = 0; gi < 64; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 6; + int dst_y = th_index * 6; + + int ex = dst_x + 6 > wout ? wout - dst_x : 6; + int ey = dst_y + 6 > hout ? 
hout - dst_y : 6; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 6) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); + } + for (int i = 0; i < ey; ++i) { + output_trans_c4_post_6x8(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); + } + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + &act_param); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); + } + for (int i = 0; i < ey; ++i) { + output_trans_c4_post_6x8(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); + } + // copy to dest + memset(trans_tmp_data, 0, 144 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 24, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + &act_param); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute + +// F(2,3) +void conv_compute_2x2_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + auto act_param = param.activation_param; + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 1) / 2; + int tile_h = (hout + 1) / 2; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64; + memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = 
g_tmp_data + threads * tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 64; + + // begin compute + for (int ni = 0; ni < num; ++ni) { + // trans input to c4 + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; +#pragma omp parallel for num_threads(threads) + for (int tbi = 0; tbi < block_count; ++tbi) { +#ifdef ARM_WITH_OMP + float* tmp_data = + g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride; + float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 64; + float* trans_remain_tmp_data = + g_trans_remain_tmp_data + omp_get_thread_num() * 64; +#else + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; +#endif + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? h_pad - src_y : 4; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4( + src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 64 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 16; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + // trans + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4(trans_remain_tmp_data, + 4, + 16, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 64; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; + for (int gi = 0; gi < 16; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 2) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + &act_param); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + // copy to dest + memset(trans_tmp_data, 0, 16 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 8, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + &act_param); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute +void conv_compute_2x2_3x3_small(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + auto act_param = param.activation_param; + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 1) / 2; + int tile_h = (hout + 1) / 2; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64; + memset(g_tmp_data, 0, tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + 64; + + // begin compute + for (int ni = 0; ni < num; ++ni) { + // trans input to c4 + + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + 
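conv_compute_2x2_3x3_small differs from conv_compute_2x2_3x3 chiefly in its threading strategy: instead of giving every OpenMP thread private scratch buffers and parallelizing over tile blocks, it keeps a single scratch area and parallelizes over the 16 per-transform-point GEMMs (the `#pragma omp parallel for` sits on the gi loop instead of the tbi loop), which saves workspace and balances better when the output plane, and hence size_tile, is small.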
chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; + for (int tbi = 0; tbi < block_count; ++tbi) { + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? h_pad - src_y : 4; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4( + src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 64 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 16; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4(trans_remain_tmp_data, + 4, + 16, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 64; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; +#pragma omp parallel for num_threads(threads) + for (int gi = 0; gi < 16; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 2) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + &act_param); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + // copy to dest + memset(trans_tmp_data, 0, 16 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 8, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr, + &act_param); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute +void output_trans_c4_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride) { + const float32x4_t src0 = vld1q_f32(src); + const float32x4_t src1 = vld1q_f32(src + src_stride); + const float32x4_t src2 = vld1q_f32(src + src_stride * 2); + const float32x4_t src3 = vld1q_f32(src + src_stride * 3); + const float32x4_t src4 = vld1q_f32(src + src_stride * 4); + const float32x4_t src5 = vld1q_f32(src + src_stride * 5); + const float32x4_t src6 = vld1q_f32(src + src_stride * 6); + const float32x4_t src7 = vld1q_f32(src + src_stride * 7); + + float32x4_t tmp024a = vaddq_f32(src1, src2); + float32x4_t tmp135a = vsubq_f32(src1, src2); + float32x4_t tmp024b = vaddq_f32(src3, src4); + float32x4_t tmp135b = vsubq_f32(src3, src4); + float32x4_t tmp024c = vaddq_f32(src5, src6); + float32x4_t tmp135c = vsubq_f32(src5, src6); + + float32x4_t dest0 = + vaddq_f32(vaddq_f32(vaddq_f32(src0, tmp024a), tmp024b), tmp024c); + float32x4_t dest2 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 4)), + vmulq_n_f32(tmp024c, 0.25f)); + float32x4_t dest4 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 16)), + vmulq_n_f32(tmp024c, 0.0625f)); + + float32x4_t dest1 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 2)), + vmulq_n_f32(tmp135c, 0.5f)); + float32x4_t dest3 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 8)), + vmulq_n_f32(tmp135c, 0.125f)); + float32x4_t dest5 = + vaddq_f32(src7, + vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 32)), + vmulq_n_f32(tmp135c, 0.03125f))); + + vst1q_f32(dest, dest0); + vst1q_f32(dest + dest_stride, dest1); + vst1q_f32(dest + dest_stride * 2, dest2); + vst1q_f32(dest + dest_stride * 3, dest3); + vst1q_f32(dest + dest_stride * 4, dest4); + vst1q_f32(dest + dest_stride * 5, dest5); +} + +void output_trans_c4_post_6x8(const float* src, + int src_stride, + float* 
dest, + int dest_stride, + float* bias_value, + bool has_relu = false) { + const float32x4_t src0 = vld1q_f32(src); + const float32x4_t src1 = vld1q_f32(src + src_stride); + const float32x4_t src2 = vld1q_f32(src + src_stride * 2); + const float32x4_t src3 = vld1q_f32(src + src_stride * 3); + const float32x4_t src4 = vld1q_f32(src + src_stride * 4); + const float32x4_t src5 = vld1q_f32(src + src_stride * 5); + const float32x4_t src6 = vld1q_f32(src + src_stride * 6); + const float32x4_t src7 = vld1q_f32(src + src_stride * 7); + + float32x4_t tmp024a = vaddq_f32(src1, src2); + float32x4_t tmp135a = vsubq_f32(src1, src2); + float32x4_t tmp024b = vaddq_f32(src3, src4); + float32x4_t tmp135b = vsubq_f32(src3, src4); + float32x4_t tmp024c = vaddq_f32(src5, src6); + float32x4_t tmp135c = vsubq_f32(src5, src6); + + float32x4_t dest0 = + vaddq_f32(vaddq_f32(vaddq_f32(src0, tmp024a), tmp024b), tmp024c); + float32x4_t dest2 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 4)), + vmulq_n_f32(tmp024c, 0.25f)); + float32x4_t dest4 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 16)), + vmulq_n_f32(tmp024c, 0.0625f)); + + float32x4_t dest1 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 2)), + vmulq_n_f32(tmp135c, 0.5f)); + float32x4_t dest3 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 8)), + vmulq_n_f32(tmp135c, 0.125f)); + float32x4_t dest5 = + vaddq_f32(src7, + vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 32)), + vmulq_n_f32(tmp135c, 0.03125f))); + + if (bias_value) { + float32x4_t bias = vld1q_f32(bias_value); + dest0 = vaddq_f32(dest0, bias); + dest1 = vaddq_f32(dest1, bias); + dest2 = vaddq_f32(dest2, bias); + dest3 = vaddq_f32(dest3, bias); + dest4 = vaddq_f32(dest4, bias); + dest5 = vaddq_f32(dest5, bias); + } + + if (has_relu) { + float32x4_t zeros = vdupq_n_f32(0); + dest0 = vmaxq_f32(dest0, zeros); + dest1 = vmaxq_f32(dest1, zeros); + dest2 = vmaxq_f32(dest2, zeros); + dest3 = vmaxq_f32(dest3, zeros); + dest4 = vmaxq_f32(dest4, zeros); + dest5 = vmaxq_f32(dest5, zeros); + } + + vst1q_f32(dest, dest0); + vst1q_f32(dest + dest_stride, dest1); + vst1q_f32(dest + dest_stride * 2, dest2); + vst1q_f32(dest + dest_stride * 3, dest3); + vst1q_f32(dest + dest_stride * 4, dest4); + vst1q_f32(dest + dest_stride * 5, dest5); +} + +void input_trans_c4_8x8(const float* src, + int src_stride, + float* dest, + int dest_stride) { + float32x4_t src0 = vld1q_f32(src); + float32x4_t src1 = vld1q_f32(src + src_stride); + float32x4_t src2 = vld1q_f32(src + src_stride * 2); + float32x4_t src3 = vld1q_f32(src + src_stride * 3); + float32x4_t src4 = vld1q_f32(src + src_stride * 4); + float32x4_t src5 = vld1q_f32(src + src_stride * 5); + float32x4_t src6 = vld1q_f32(src + src_stride * 6); + float32x4_t src7 = vld1q_f32(src + src_stride * 7); + + float32x4_t dst0 = vaddq_f32(vsubq_f32(src0, src6), + vmulq_n_f32(vsubq_f32(src4, src2), 5.25)); + float32x4_t dst7 = vaddq_f32(vsubq_f32(src7, src1), + vmulq_n_f32(vsubq_f32(src3, src5), 5.25)); + + float32x4_t tmp12a = + vsubq_f32(vaddq_f32(src2, src6), vmulq_n_f32(src4, 4.25)); + float32x4_t tmp12b = + vsubq_f32(vaddq_f32(src1, src5), vmulq_n_f32(src3, 4.25)); + float32x4_t dst1 = vaddq_f32(tmp12a, tmp12b); + float32x4_t dst2 = vsubq_f32(tmp12a, tmp12b); + + float32x4_t tmp34a = vsubq_f32(vaddq_f32(src6, vmulq_n_f32(src2, 0.25)), + vmulq_n_f32(src4, 1.25)); + float32x4_t tmp34b = + vaddq_f32(vsubq_f32(vmulq_n_f32(src1, 0.5), vmulq_n_f32(src3, 2.5)), + vmulq_n_f32(src5, 2)); + float32x4_t dst3 = vaddq_f32(tmp34a, tmp34b); + float32x4_t dst4 = 
vsubq_f32(tmp34a, tmp34b); + + float32x4_t tmp56a = + vaddq_f32(src6, vmulq_n_f32(vsubq_f32(src2, vmulq_n_f32(src4, 1.25)), 4)); + float32x4_t tmp56b = + vaddq_f32(vsubq_f32(vmulq_n_f32(src1, 2), vmulq_n_f32(src3, 2.5)), + vmulq_n_f32(src5, 0.5)); + float32x4_t dst5 = vaddq_f32(tmp56a, tmp56b); + float32x4_t dst6 = vsubq_f32(tmp56a, tmp56b); + + vst1q_f32(dest, dst0); + vst1q_f32(dest + dest_stride, dst1); + vst1q_f32(dest + dest_stride * 2, dst2); + vst1q_f32(dest + dest_stride * 3, dst3); + vst1q_f32(dest + dest_stride * 4, dst4); + vst1q_f32(dest + dest_stride * 5, dst5); + vst1q_f32(dest + dest_stride * 6, dst6); + vst1q_f32(dest + dest_stride * 7, dst7); +} + +// BT=[1, 0, -1, 0, +// 0, 1, 1, 0, +// 0, -1, 1, 0, +// 0, 1, 0, -1] +void input_trans_c4_4x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride) { + float32x4_t src00 = vld1q_f32(src); + float32x4_t src01 = vld1q_f32(src + src_stride); + float32x4_t src02 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src10 = vld1q_f32(src); + float32x4_t src11 = vld1q_f32(src + src_stride); + float32x4_t src12 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src20 = vld1q_f32(src); + float32x4_t src21 = vld1q_f32(src + src_stride); + float32x4_t src22 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src30 = vld1q_f32(src); + float32x4_t src31 = vld1q_f32(src + src_stride); + float32x4_t src32 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride); + + float32x4_t dst00 = vsubq_f32(src00, src02); + float32x4_t dst10 = vaddq_f32(src01, src02); + float32x4_t dst20 = vsubq_f32(src02, src01); + float32x4_t dst30 = vsubq_f32(src01, src03); + + float32x4_t dst01 = vsubq_f32(src10, src12); + float32x4_t dst11 = vaddq_f32(src11, src12); + float32x4_t dst21 = vsubq_f32(src12, src11); + float32x4_t dst31 = vsubq_f32(src11, src13); + + float32x4_t dst02 = vsubq_f32(src20, src22); + float32x4_t dst12 = vaddq_f32(src21, src22); + float32x4_t dst22 = vsubq_f32(src22, src21); + float32x4_t dst32 = vsubq_f32(src21, src23); + + float32x4_t dst03 = vsubq_f32(src30, src32); + float32x4_t dst13 = vaddq_f32(src31, src32); + float32x4_t dst23 = vsubq_f32(src32, src31); + float32x4_t dst33 = vsubq_f32(src31, src33); + + float32x4_t dest00 = vsubq_f32(dst00, dst02); + float32x4_t dest10 = vaddq_f32(dst01, dst02); + float32x4_t dest20 = vsubq_f32(dst02, dst01); + float32x4_t dest30 = vsubq_f32(dst01, dst03); + + float32x4_t dest01 = vsubq_f32(dst10, dst12); + float32x4_t dest11 = vaddq_f32(dst11, dst12); + float32x4_t dest21 = vsubq_f32(dst12, dst11); + float32x4_t dest31 = vsubq_f32(dst11, dst13); + + float32x4_t dest02 = vsubq_f32(dst20, dst22); + float32x4_t dest12 = vaddq_f32(dst21, dst22); + float32x4_t dest22 = vsubq_f32(dst22, dst21); + float32x4_t dest32 = vsubq_f32(dst21, dst23); + + float32x4_t dest03 = vsubq_f32(dst30, dst32); + float32x4_t dest13 = vaddq_f32(dst31, dst32); + float32x4_t dest23 = vsubq_f32(dst32, dst31); + float32x4_t dest33 = vsubq_f32(dst31, dst33); + + vst1q_f32(dest, dest00); + vst1q_f32(dest + dest_stride, dest10); + vst1q_f32(dest + dest_stride + dest_stride, dest20); + vst1q_f32(dest + 
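The two 1-D transforms above compose into the standard Winograd form: with d a 4x4 input tile, g the 3x3 kernel, and BT as given in the comment, the output tile is

    Y = AT * [ (G * g * GT) (.) (BT * d * B) ] * A

where (.) denotes elementwise multiplication, AT is the 2x4 matrix shown before output_trans_c4_post_2x4, and G is the 4x3 weight transform applied in weight_trans_c4_4x4; each float32x4 lane carries one channel of the c4 block.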
dest_stride + dest_stride + dest_stride, dest30); + dest += dest_h_stride; + vst1q_f32(dest, dest01); + vst1q_f32(dest + dest_stride, dest11); + vst1q_f32(dest + dest_stride + dest_stride, dest21); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest31); + dest += dest_h_stride; + vst1q_f32(dest, dest02); + vst1q_f32(dest + dest_stride, dest12); + vst1q_f32(dest + dest_stride + dest_stride, dest22); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest32); + dest += dest_h_stride; + vst1q_f32(dest, dest03); + vst1q_f32(dest + dest_stride, dest13); + vst1q_f32(dest + dest_stride + dest_stride, dest23); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest33); +} + +// AT=[1, 1, 1, 0, +// 0, 1, -1, -1] +void output_trans_c4_post_2x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride, + float* bias_value, + bool has_relu) { + float32x4_t src00 = vld1q_f32(src); + float32x4_t src01 = vld1q_f32(src + src_stride); + float32x4_t src02 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src10 = vld1q_f32(src); + float32x4_t src11 = vld1q_f32(src + src_stride); + float32x4_t src12 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src20 = vld1q_f32(src); + float32x4_t src21 = vld1q_f32(src + src_stride); + float32x4_t src22 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src30 = vld1q_f32(src); + float32x4_t src31 = vld1q_f32(src + src_stride); + float32x4_t src32 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride); + + float32x4_t dst00 = vaddq_f32(vaddq_f32(src00, src01), src02); + float32x4_t dst10 = vsubq_f32(vsubq_f32(src01, src02), src03); + float32x4_t dst01 = vaddq_f32(vaddq_f32(src10, src11), src12); + float32x4_t dst11 = vsubq_f32(vsubq_f32(src11, src12), src13); + float32x4_t dst02 = vaddq_f32(vaddq_f32(src20, src21), src22); + float32x4_t dst12 = vsubq_f32(vsubq_f32(src21, src22), src23); + float32x4_t dst03 = vaddq_f32(vaddq_f32(src30, src31), src32); + float32x4_t dst13 = vsubq_f32(vsubq_f32(src31, src32), src33); + + float32x4_t dest00 = vaddq_f32(vaddq_f32(dst00, dst01), dst02); + float32x4_t dest10 = vsubq_f32(vsubq_f32(dst01, dst02), dst03); + float32x4_t dest01 = vaddq_f32(vaddq_f32(dst10, dst11), dst12); + float32x4_t dest11 = vsubq_f32(vsubq_f32(dst11, dst12), dst13); + + if (bias_value) { + float32x4_t bias = vld1q_f32(bias_value); + dest00 = vaddq_f32(dest00, bias); + dest10 = vaddq_f32(dest10, bias); + dest01 = vaddq_f32(dest01, bias); + dest11 = vaddq_f32(dest11, bias); + } + + if (has_relu) { + float32x4_t zeros = vdupq_n_f32(0); + dest00 = vmaxq_f32(dest00, zeros); + dest10 = vmaxq_f32(dest10, zeros); + dest01 = vmaxq_f32(dest01, zeros); + dest11 = vmaxq_f32(dest11, zeros); + } + + vst1q_f32(dest, dest00); + vst1q_f32(dest + dest_stride, dest10); + dest += dest_h_stride; + vst1q_f32(dest, dest01); + vst1q_f32(dest + dest_stride, dest11); +} +void weight_trans_c4_8x8( + float* dest, const float* din, int ch_in, int ch_out, void* workspace) { + const float coeff[8][3] = {{1.0f, 0.0f, 0.0f}, + {-2.0f / 9, -2.0f / 9, -2.0f / 9}, + {-2.0f / 9, 2.0f / 9, -2.0f / 9}, + {1.0f / 90, 1.0f / 45, 2.0f / 45}, 
+                             {1.0f / 90, -1.0f / 45, 2.0f / 45},
+                             {32.0f / 45, 16.0f / 45, 8.0f / 45},
+                             {32.0f / 45, -16.0f / 45, 8.0f / 45},
+                             {0.0f, 0.0f, 1.0f}};
+
+  float* ptr_out = static_cast<float*>(workspace);
+
+  for (int i = 0; i < ch_out; i++) {
+    for (int j = 0; j < ch_in; j++) {
+      const float* kernel0 =
+          static_cast<const float*>(din) + (i * ch_in + j) * 9;
+      float* ptr_channel = ptr_out + (i * ch_in + j) * 64;
+
+      //! transform kernel, transposed
+      const float* k0 = kernel0;
+      const float* k1 = kernel0 + 3;
+      const float* k2 = kernel0 + 6;
+
+      //! h
+      float tmp[8][3];
+      for (int i = 0; i < 8; i++) {
+        tmp[i][0] =
+            k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
+        tmp[i][1] =
+            k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
+        tmp[i][2] =
+            k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
+      }
+
+      //! v
+      for (int j = 0; j < 8; j++) {
+        float* tmpp = &tmp[j][0];
+        for (int i = 0; i < 8; i++) {
+          ptr_channel[j * 8 + i] = tmpp[0] * coeff[i][0] +
+                                   tmpp[1] * coeff[i][1] +
+                                   tmpp[2] * coeff[i][2];
+        }
+      }
+    }
+  }
+
+  int oc_pad = (ch_out + 3) / 4 * 4;
+  int ic_pad = (ch_in + 3) / 4 * 4;
+  int c_stride = ic_pad * oc_pad;
+  for (int i = 0; i < ch_out * ch_in * 64; ++i) {
+    int new_c = i % 64;
+    int new_oc = i / ch_in / 64 / 4;
+    int new_ic = i / 64 % (ch_in * 4) % ch_in;
+    int new_inner = i / ch_in / 64 % 4;
+    int dest_ind =
+        new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
+    dest[dest_ind] = ptr_out[i];
+  }
+}
+
+void weight_trans_c4_4x4(
+    float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
+  const float coeff[4][3] = {{1.0f, 0.0f, 0.0f},
+                             {0.5f, 0.5f, 0.5f},
+                             {0.5f, -0.5f, 0.5f},
+                             {0.0f, 0.0f, 1.0f}};
+
+  float* ptr_out = static_cast<float*>(workspace);
+
+  for (int i = 0; i < ch_out; i++) {
+    for (int j = 0; j < ch_in; j++) {
+      const float* kernel0 =
+          static_cast<const float*>(din) + (i * ch_in + j) * 9;
+      float* ptr_channel = ptr_out + (i * ch_in + j) * 16;
+
+      //! transform kernel, transposed
+      const float* k0 = kernel0;
+      const float* k1 = kernel0 + 3;
+      const float* k2 = kernel0 + 6;
+
+      //! h
+      float tmp[4][3];
+      for (int i = 0; i < 4; i++) {
+        tmp[i][0] =
+            k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
+        tmp[i][1] =
+            k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
+        tmp[i][2] =
+            k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
+      }
+
+      //! v
+      for (int j = 0; j < 4; j++) {
+        float* tmpp = &tmp[j][0];
+        for (int i = 0; i < 4; i++) {
+          ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
+                                   tmpp[1] * coeff[i][1] +
+                                   tmpp[2] * coeff[i][2];
+        }
+      }
+    }
+  }
+
+  int oc_pad = (ch_out + 3) / 4 * 4;
+  int ic_pad = (ch_in + 3) / 4 * 4;
+  int c_stride = ic_pad * oc_pad;
+  for (int i = 0; i < ch_out * ch_in * 16; ++i) {
+    int new_c = i % 16;
+    int new_oc = i / ch_in / 16 / 4;
+    int new_ic = i / 16 % (ch_in * 4) % ch_in;
+    int new_inner = i / ch_in / 16 % 4;
+    int dest_ind =
+        new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
+    dest[dest_ind] = ptr_out[i];
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc
deleted file mode 100644
index 99aeea8bdea2a50795dcdca18464a196ee877291..0000000000000000000000000000000000000000
--- a/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc
+++ /dev/null
@@ -1,538 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
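The two `weight_trans_c4_*` routines above implement the Winograd kernel transform U = G g Gᵀ per 3x3 kernel (the 8x3 coefficient table is G for F(6x6,3x3), the 4x3 table for F(2x2,3x3)), storing a transposed intermediate (their "transform kernel, transposed" step) before repacking U into the padded c4 block layout. As a plain scalar cross-check of the transform math only, not the repack, something like the following should agree lane-for-lane with the F(2x2,3x3) path; `winograd_kernel_transform_f2x2` is a hypothetical helper name, not part of the patch:

```cpp
// Scalar reference for U = G * g * G^T, F(2x2, 3x3).
// G is the same 4x3 table used by weight_trans_c4_4x4 above.
static const float G[4][3] = {{1.0f, 0.0f, 0.0f},
                              {0.5f, 0.5f, 0.5f},
                              {0.5f, -0.5f, 0.5f},
                              {0.0f, 0.0f, 1.0f}};

// g: one 3x3 kernel (row major); u: 4x4 transformed kernel (row major).
void winograd_kernel_transform_f2x2(const float g[9], float u[16]) {
  float tmp[4][3];  // tmp = G * g  (4x3)
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 3; ++j) {
      tmp[i][j] = G[i][0] * g[j] + G[i][1] * g[3 + j] + G[i][2] * g[6 + j];
    }
  }
  for (int i = 0; i < 4; ++i) {  // u = tmp * G^T  (4x4)
    for (int j = 0; j < 4; ++j) {
      u[i * 4 + j] =
          tmp[i][0] * G[j][0] + tmp[i][1] * G[j][1] + tmp[i][2] * G[j][2];
    }
  }
}
```

The repack loop that follows in the real code then permutes U into `new_c * c_stride + new_oc * ic_pad * 4 + ...` order so the GEMM stage can stream one 4-channel block per transformed coordinate.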
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "lite/backends/arm/math/conv_block_utils.h" -#include "lite/backends/arm/math/conv_impl.h" -#include "lite/core/context.h" -#include "lite/operators/op_params.h" -#ifdef ARM_WITH_OMP -#include -#endif - -namespace paddle { -namespace lite { -namespace arm { -namespace math { -void conv_3x3s1_depthwise_fp32(const float* i_data, - float* o_data, - int bs, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, - const float* weights, - const float* bias, - const operators::ConvParam& param, - ARMContext* ctx) { - int threads = ctx->threads(); - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; - const int out_c_block = 4; - const int out_h_kernel = 2; - const int out_w_kernel = 4; - const int win_ext = ow + 2; - const int ow_round = ROUNDUP(ow, 4); - const int win_round = ROUNDUP(win_ext, 4); - const int hin_round = oh + 2; - const int prein_size = win_round * hin_round * out_c_block; - auto workspace_size = - threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; - ctx->ExtendWorkspace(sizeof(float) * workspace_size); - - bool flag_relu = param.fuse_relu; - bool flag_bias = param.bias != nullptr; - - /// get workspace - float* ptr_zero = ctx->workspace_data(); - memset(ptr_zero, 0, sizeof(float) * win_round); - float* ptr_write = ptr_zero + win_round; - - int size_in_channel = win * ih; - int size_out_channel = ow * oh; - - int ws = -pad_w; - int we = ws + win_round; - int hs = -pad_h; - int he = hs + hin_round; - int w_loop = ow_round / 4; - auto remain = w_loop * 4 - ow; - bool flag_remain = remain > 0; - remain = 4 - remain; - remain = remain > 0 ? 
remain : 0; - int row_len = win_round * out_c_block; - - for (int n = 0; n < bs; ++n) { - const float* din_batch = i_data + n * ic * size_in_channel; - float* dout_batch = o_data + n * oc * size_out_channel; -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < oc; c += out_c_block) { -#ifdef ARM_WITH_OMP - float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; -#else - float* pre_din = ptr_write + ow_round; -#endif - /// const array size - float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT - prepack_input_nxwc4_dw( - din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); - const float* weight_c = weights + c * 9; // kernel_w * kernel_h - float* dout_c00 = dout_batch + c * size_out_channel; - float bias_local[4] = {0, 0, 0, 0}; - if (flag_bias) { - bias_local[0] = bias[c]; - bias_local[1] = bias[c + 1]; - bias_local[2] = bias[c + 2]; - bias_local[3] = bias[c + 3]; - } - float32x4_t vbias = vld1q_f32(bias_local); -#ifdef __aarch64__ - float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 - float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 - float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 - float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 - float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 - float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 - float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 - float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 - float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 -#endif - for (int h = 0; h < oh; h += out_h_kernel) { - float* outc00 = dout_c00 + h * ow; - float* outc01 = outc00 + ow; - float* outc10 = outc00 + size_out_channel; - float* outc11 = outc10 + ow; - float* outc20 = outc10 + size_out_channel; - float* outc21 = outc20 + ow; - float* outc30 = outc20 + size_out_channel; - float* outc31 = outc30 + ow; - const float* inr0 = pre_din + h * row_len; - const float* inr1 = inr0 + row_len; - const float* inr2 = inr1 + row_len; - const float* inr3 = inr2 + row_len; - if (c + out_c_block > oc) { - switch (c + out_c_block - oc) { - case 3: - outc10 = ptr_write; - outc11 = ptr_write; - case 2: - outc20 = ptr_write; - outc21 = ptr_write; - case 1: - outc30 = ptr_write; - outc31 = ptr_write; - default: - break; - } - } - if (h + out_h_kernel > oh) { - outc01 = ptr_write; - outc11 = ptr_write; - outc21 = ptr_write; - outc31 = ptr_write; - } - float* outl[] = {outc00, - outc10, - outc20, - outc30, - outc01, - outc11, - outc21, - outc31, - reinterpret_cast(bias_local), - reinterpret_cast(flag_relu)}; - void* outl_ptr = reinterpret_cast(outl); - for (int w = 0; w < w_loop; ++w) { - bool flag_mask = (w == w_loop - 1) && flag_remain; - float* out0 = pre_out; -// clang-format off -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ - "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ - "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ - "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ - "ldp q4, q5, [%[inr0]]\n" /* load input r0*/ - "ldp q10, q11, [%[inr1]]\n" /* load input r1*/ - /* r0, r1, mul w0, get out r0, r1 */ - "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ - "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ - "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ - "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ - "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ - "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ - "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ - "fmul v22.4s , 
%[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ - /* r0, r1, mul w1, get out r0, r1 */ - "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ - "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ - "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ - "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ - "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ - "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ - "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ - "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ - "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ - /* r0, r1, mul w2, get out r0, r1 */ - "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ - "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ - "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ - "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ - "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ - "ldp q4, q5, [%[inr2]]\n" /* load input r2*/ - "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ - "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ - "fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ - "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ - /* r1, r2, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ - "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ - "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ - "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ - "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ - "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ - "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ - /* r1, r2, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ - "ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ - "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ - "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ - "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ - "ldp x0, x1, [%[outl]] \n" - "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ - "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ - "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ - "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ - /* r1, r2, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ - "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ - "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ - "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ - "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ - "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ - "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ - "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ - "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ - /* r2, r3, mul w6, get out r0, r1 */ - "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ - "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ - "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ - "ldp x2, x3, [%[outl], #16] \n" - "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ - "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ - "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ - "fmla v22.4s , %[w6].4s, v9.4s\n" /* 
outr13 = w6 * r3[3]*/ - /* r2, r3, mul w7, get out r0, r1 */ - "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ - "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ - "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ - "ldp x4, x5, [%[outl], #32] \n" - "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ - "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ - "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ - "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ - /* r2, r3, mul w8, get out r0, r1 */ - "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ - "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ - "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ - "ldp x6, x7, [%[outl], #48] \n" - "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ - "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ - "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ - "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ - - "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ - "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ - "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ - "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ - "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ - "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ - "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ - "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ - - /* transpose */ - "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ - "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ - "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ - "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ - "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ - "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ - "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ - "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ - "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ - "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ - "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ - "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ - "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ - "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ - "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ - "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ - - "cbz %w[flag_relu], 0f\n" /* skip relu*/ - "movi v0.4s, #0\n" /* for relu */ - "fmax v15.4s, v15.4s, v0.4s\n" - "fmax v16.4s, v16.4s, v0.4s\n" - "fmax v17.4s, v17.4s, v0.4s\n" - "fmax v18.4s, v18.4s, v0.4s\n" - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - "0:\n" - "cbnz %w[flag_mask], 1f\n" - "str q15, [x0]\n" /* save outc00 */ - "str q16, [x4]\n" /* save outc01 */ - "str q17, [x1]\n" /* save outc10 */ - "str q18, [x5]\n" /* save outc11 */ - "str q19, [x2]\n" /* save outc20 */ - "str q20, [x6]\n" /* save outc21 */ - "str q21, [x3]\n" /* save outc30 */ - "str q22, [x7]\n" /* save outc31 */ - "b 2f\n" - "1:\n" - "str q15, [%[out]], #16 \n" /* save remain to pre_out */ - "str q17, [%[out]], #16 \n" /* save remain to pre_out */ - "str q19, [%[out]], #16 \n" /* save remain to pre_out */ - "str q21, [%[out]], #16 \n" /* save remain to pre_out */ - "str q16, [%[out]], #16 \n" /* save remain to pre_out */ - "str q18, [%[out]], #16 \n" /* save remain to pre_out */ - "str q20, [%[out]], #16 \n" /* save remain to pre_out */ - "str q22, [%[out]], #16 \n" /* 
save remain to pre_out */ - "2:\n" - :[inr0] "+r"(inr0), [inr1] "+r"(inr1), - [inr2] "+r"(inr2), [inr3] "+r"(inr3), - [out]"+r"(out0) - :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), - [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), - [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), - [vbias]"w" (vbias), [outl] "r" (outl_ptr), - [flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu) - : "cc", "memory", - "v0","v1","v2","v3","v4","v5","v6","v7", - "v8", "v9", "v10", "v11", "v15", - "v16","v17","v18","v19","v20","v21","v22", - "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7" - ); -#else - asm volatile( - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" - /* load r0, r1 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" - "vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n" - /* main loop */ - "0: @ main loop\n" - /* mul r0 with w0, w1, w2, get out r0 */ - "vmul.f32 q8, q5, q0 @ w0 * inr00\n" - "vmul.f32 q9, q5, q1 @ w0 * inr01\n" - "vmul.f32 q10, q5, q2 @ w0 * inr02\n" - "vmul.f32 q11, q5, q3 @ w0 * inr03\n" - "vmla.f32 q8, q6, q1 @ w1 * inr01\n" - "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" - "vmla.f32 q9, q6, q2 @ w1 * inr02\n" - "vmla.f32 q10, q6, q3 @ w1 * inr03\n" - "vmla.f32 q11, q6, q0 @ w1 * inr04\n" - "vmla.f32 q8, q7, q2 @ w2 * inr02\n" - "vmla.f32 q9, q7, q3 @ w2 * inr03\n" - "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" - "vmla.f32 q10, q7, q0 @ w2 * inr04\n" - "vmla.f32 q11, q7, q1 @ w2 * inr05\n" - "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" - "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" - /* mul r1 with w0-w5, get out r0, r1 */ - "vmul.f32 q12, q5, q2 @ w0 * inr10\n" - "vmul.f32 q13, q5, q3 @ w0 * inr11\n" - "vmul.f32 q14, q5, q0 @ w0 * inr12\n" - "vmul.f32 q15, q5, q1 @ w0 * inr13\n" - "vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n" - "vmla.f32 q8, q4, q2 @ w3 * inr10\n" - "vmla.f32 q9, q4, q3 @ w3 * inr11\n" - "vmla.f32 q10, q4, q0 @ w3 * inr12\n" - "vmla.f32 q11, q4, q1 @ w3 * inr13\n" - /* mul r1 with w1, w4, get out r1, r0 */ - "vmla.f32 q8, q5, q3 @ w4 * inr11\n" - "vmla.f32 q12, q6, q3 @ w1 * inr11\n" - "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" - "vmla.f32 q9, q5, q0 @ w4 * inr12\n" - "vmla.f32 q13, q6, q0 @ w1 * inr12\n" - "vmla.f32 q10, q5, q1 @ w4 * inr13\n" - "vmla.f32 q14, q6, q1 @ w1 * inr13\n" - "vmla.f32 q11, q5, q2 @ w4 * inr14\n" - "vmla.f32 q15, q6, q2 @ w1 * inr14\n" - "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" - /* mul r1 with w2, w5, get out r1, r0 */ - "vmla.f32 q12, q7, q0 @ w2 * inr12\n" - "vmla.f32 q13, q7, q1 @ w2 * inr13\n" - "vmla.f32 q8, q6, q0 @ w5 * inr12\n" - "vmla.f32 q9, q6, q1 @ w5 * inr13\n" - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" - "vmla.f32 q14, q7, q2 @ w2 * inr14\n" - "vmla.f32 q15, q7, q3 @ w2 * inr15\n" - "vmla.f32 q10, q6, q2 @ w5 * inr14\n" - "vmla.f32 q11, q6, q3 @ w5 * inr15\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" - /* mul r2 with w3-w8, get out r0, r1 */ - "vmla.f32 q12, q4, q0 @ w3 * inr20\n" - "vmla.f32 q13, q4, q1 @ w3 * inr21\n" - "vmla.f32 q14, q4, q2 @ w3 * inr22\n" - "vmla.f32 q15, q4, q3 @ w3 * inr23\n" - "vld1.32 {d8-d9}, [%[wc0]]! 
@ load w7, to q4\n" - "vmla.f32 q8, q7, q0 @ w6 * inr20\n" - "vmla.f32 q9, q7, q1 @ w6 * inr21\n" - "vmla.f32 q10, q7, q2 @ w6 * inr22\n" - "vmla.f32 q11, q7, q3 @ w6 * inr23\n" - /* mul r2 with w4, w7, get out r1, r0 */ - "vmla.f32 q8, q4, q1 @ w7 * inr21\n" - "vmla.f32 q12, q5, q1 @ w4 * inr21\n" - "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" - "vmla.f32 q9, q4, q2 @ w7 * inr22\n" - "vmla.f32 q13, q5, q2 @ w4 * inr22\n" - "vmla.f32 q10, q4, q3 @ w7 * inr23\n" - "vmla.f32 q14, q5, q3 @ w4 * inr23\n" - "vmla.f32 q11, q4, q0 @ w7 * inr24\n" - "vmla.f32 q15, q5, q0 @ w4 * inr24\n" - "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" - /* mul r1 with w5, w8, get out r1, r0 */ - "vmla.f32 q12, q6, q2 @ w5 * inr22\n" - "vmla.f32 q13, q6, q3 @ w5 * inr23\n" - "vmla.f32 q8, q5, q2 @ w8 * inr22\n" - "vmla.f32 q9, q5, q3 @ w8 * inr23\n" - "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" - "ldr r4, [%[outl], #32] @ load bias addr to r4\n" - "vmla.f32 q14, q6, q0 @ w5 * inr24\n" - "vmla.f32 q15, q6, q1 @ w5 * inr25\n" - "vmla.f32 q10, q5, q0 @ w8 * inr24\n" - "vmla.f32 q11, q5, q1 @ w8 * inr25\n" - "vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n" - "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" - /* mul r3 with w6, w7, w8, get out r1 */ - "vmla.f32 q12, q7, q2 @ w6 * inr30\n" - "vmla.f32 q13, q7, q3 @ w6 * inr31\n" - "vmla.f32 q14, q7, q0 @ w6 * inr32\n" - "vmla.f32 q15, q7, q1 @ w6 * inr33\n" - "vmla.f32 q12, q4, q3 @ w7 * inr31\n" - "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" - "vld1.32 {d12-d13}, [r4] @ load bias\n" - "vmla.f32 q13, q4, q0 @ w7 * inr32\n" - "vmla.f32 q14, q4, q1 @ w7 * inr33\n" - "vmla.f32 q15, q4, q2 @ w7 * inr34\n" - "ldr r0, [%[outl]] @ load outc00 to r0\n" - "vmla.f32 q12, q5, q0 @ w8 * inr32\n" - "vmla.f32 q13, q5, q1 @ w8 * inr33\n" - "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" - "vmla.f32 q14, q5, q2 @ w8 * inr34\n" - "vmla.f32 q15, q5, q3 @ w8 * inr35\n" - "ldr r1, [%[outl], #4] @ load outc10 to r1\n" - "vadd.f32 q8, q8, q6 @ r00 add bias\n" - "vadd.f32 q9, q9, q6 @ r01 add bias\n" - "vadd.f32 q10, q10, q6 @ r02 add bias\n" - "vadd.f32 q11, q11, q6 @ r03 add bias\n" - "ldr r2, [%[outl], #8] @ load outc20 to r2\n" - "vadd.f32 q12, q12, q6 @ r10 add bias\n" - "vadd.f32 q13, q13, q6 @ r11 add bias\n" - "vadd.f32 q14, q14, q6 @ r12 add bias\n" - "vadd.f32 q15, q15, q6 @ r13 add bias\n" - "ldr r3, [%[outl], #12] @ load outc30 to r3\n" - "vmov.u32 q7, #0 @ mov zero to q7\n" - "cmp r5, #0 @ cmp flag relu\n" - "beq 1f @ skip relu\n" - "vmax.f32 q8, q8, q7 @ r00 relu\n" - "vmax.f32 q9, q9, q7 @ r01 relu\n" - "vmax.f32 q10, q10, q7 @ r02 relu\n" - "vmax.f32 q11, q11, q7 @ r03 relu\n" - "vmax.f32 q12, q12, q7 @ r10 relu\n" - "vmax.f32 q13, q13, q7 @ r11 relu\n" - "vmax.f32 q14, q14, q7 @ r12 relu\n" - "vmax.f32 q15, q15, q7 @ r13 relu\n" - "1:\n" - "ldr r4, [%[outl], #16] @ load outc01 to r4\n" - "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" - "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" - "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" - "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" - "ldr r5, [%[outl], #20] @ load outc11 to r5\n" - "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" - "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" - "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" - "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" - "cmp %[flag_mask], #0 @ cmp flag mask\n" - "bne 2f\n" - "vst1.32 {d16-d17}, [r0] @ save outc00\n" - "vst1.32 {d18-d19}, [r1] @ save outc10\n" - "vst1.32 {d20-d21}, [r2] @ save 
outc20\n" - "vst1.32 {d22-d23}, [r3] @ save outc30\n" - "vst1.32 {d24-d25}, [r4] @ save outc01\n" - "vst1.32 {d26-d27}, [r5] @ save outc11\n" - "ldr r0, [%[outl], #24] @ load outc21 to r0\n" - "ldr r1, [%[outl], #28] @ load outc31 to r1\n" - "vst1.32 {d28-d29}, [r0] @ save outc21\n" - "vst1.32 {d30-d31}, [r1] @ save outc31\n" - "b 3f @ branch end\n" - "2: \n" - "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" - "3: \n" - : [r0] "+r"(inr0), [r1] "+r"(inr1), - [r2] "+r"(inr2), [r3] "+r"(inr3), - [out0] "+r"(out0), [wc0] "+r"(weight_c) - : [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr) - : "cc", "memory", - "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5" - ); -#endif // __arch64__ - // clang-format on - outl[0] += 4; - outl[1] += 4; - outl[2] += 4; - outl[3] += 4; - outl[4] += 4; - outl[5] += 4; - outl[6] += 4; - outl[7] += 4; - if (flag_mask) { - memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); - memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); - memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); - memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); - memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); - memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); - memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); - memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); - } - } - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc index 6a1fa37681585883280625a22c15aec43c6554af..5cee02b639af7e04a9184af765a5e96be4cb4cdb 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc @@ -35,9 +35,10 @@ size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param, auto dim_in = param.x->dims(); auto dim_out = param.output->dims(); const int threads = ctx->threads(); + auto paddings = *param.paddings; int llc_size = ctx->llc_size() / sizeof(float); - const int pad_w = param.paddings[1]; - const int pad_h = param.paddings[0]; + const int pad_w = paddings[2]; + const int pad_h = paddings[0]; int ow = dim_out[3]; int oh = dim_out[2]; int ic = dim_in[1]; @@ -74,9 +75,11 @@ void conv_3x3s1_direct_fp32(const float* i_data, ARMContext* ctx) { const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); + auto paddings = *param.paddings; + auto act_param = param.activation_param; - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round + 2; bool flag_relu = param.fuse_relu; @@ -467,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } const float* weight_remain_ptr = weights + c_round_down * w_stride; #pragma omp parallel for num_threads(threads) @@ -778,7 +782,8 
@@ void conv_3x3s1_direct_fp32(const float* i_data,
                        oh,
                        ow,
                        flag_relu,
-                       ptr_write);
+                       ptr_write,
+                       &act_param);
     }
   }
 }
diff --git a/lite/backends/arm/math/conv3x3s1_direct_int8.cc b/lite/backends/arm/math/conv3x3s1_direct_int8.cc
index f966313e118acf3f74124aca1d16aa3c50009bb8..64e72bc441bb93fa955e12ff53ce17f0e37b4830 100644
--- a/lite/backends/arm/math/conv3x3s1_direct_int8.cc
+++ b/lite/backends/arm/math/conv3x3s1_direct_int8.cc
@@ -41,10 +41,11 @@ void conv_3x3s1_direct_int8(const int8_t* din,
                             const operators::ConvParam& param,
                             Context<TARGET(kARM)>* ctx,
                             const float* scale) {
+  auto paddings = *param.paddings;
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias;
-  int pad_h = param.paddings[0];
-  int pad_w = param.paddings[1];
+  int pad_h = paddings[0];
+  int pad_w = paddings[2];
   const int threads = ctx->threads();
   int llc_size = ctx->llc_size() / 4;
diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc
new file mode 100644
index 0000000000000000000000000000000000000000..66d61413fc43fd518e0b34c7bc8d7b7bf5cc72a7
--- /dev/null
+++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc
@@ -0,0 +1,4094 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
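The hunks above stop indexing `param.paddings` directly and instead copy `*param.paddings`, then read `paddings[0]` for `pad_h` and `paddings[2]` for `pad_w`. That only works if `paddings` is now a shared four-element vector ordered {top, bottom, left, right}, which lets asymmetric padding be expressed. A minimal sketch under that assumption; `ConvParamSketch` and `read_pads` are illustrative stand-ins, not the real `operators::ConvParam` API:

```cpp
#include <memory>
#include <vector>

// Assumed layout behind `auto paddings = *param.paddings;`:
// a shared 4-element vector {top, bottom, left, right}.
struct ConvParamSketch {
  std::shared_ptr<std::vector<int>> paddings;
};

void read_pads(const ConvParamSketch& param, int* pad_h, int* pad_w) {
  auto paddings = *param.paddings;  // copy the 4-element vector
  *pad_h = paddings[0];             // top padding drives pad_h
  *pad_w = paddings[2];             // left padding drives pad_w (index 2, not 1)
}
```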
+
+#include <arm_neon.h>
+#include "lite/backends/arm/math/conv_depthwise.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void conv_depthwise_3x3s1p0_bias(float *dout, const float *din,
+                                 const float *weights, const float *bias,
+                                 bool flag_bias, const int num,
+                                 const int ch_in, const int h_in,
+                                 const int w_in, const int h_out,
+                                 const int w_out,
+                                 const operators::ActivationParam act_param,
+                                 ARMContext *ctx);
+
+void conv_depthwise_3x3s1p0_bias_s(float *dout, const float *din,
+                                   const float *weights, const float *bias,
+                                   bool flag_bias, const int num,
+                                   const int ch_in, const int h_in,
+                                   const int w_in, const int h_out,
+                                   const int w_out,
+                                   const operators::ActivationParam act_param,
+                                   ARMContext *ctx);
+
+void conv_depthwise_3x3s1p1_bias(float *dout, const float *din,
+                                 const float *weights, const float *bias,
+                                 bool flag_bias, const int num,
+                                 const int ch_in, const int h_in,
+                                 const int w_in, const int h_out,
+                                 const int w_out,
+                                 const operators::ActivationParam act_param,
+                                 ARMContext *ctx);
+
+void conv_depthwise_3x3s1p1_bias_s(float *dout, const float *din,
+                                   const float *weights, const float *bias,
+                                   bool flag_bias, const int num,
+                                   const int ch_in, const int h_in,
+                                   const int w_in, const int h_out,
+                                   const int w_out,
+                                   const operators::ActivationParam act_param,
+                                   ARMContext *ctx);
+
+// Dispatch on padding (p0/p1) and on input width: narrow inputs
+// (w_in <= 5 for pad 0, w_in <= 4 for pad 1) take the "_s" small-width kernels.
+void conv_depthwise_3x3s1_fp32(const float *din, float *dout, int num,
+                               int ch_out, int h_out, int w_out, int ch_in,
+                               int h_in, int w_in, const float *weights,
+                               const float *bias, int pad, bool flag_bias,
+                               const operators::ActivationParam act_param,
+                               ARMContext *ctx) {
+  if (pad == 0) {
+    if (w_in > 5) {
+      conv_depthwise_3x3s1p0_bias(dout, din, weights, bias, flag_bias, num,
+                                  ch_in, h_in, w_in, h_out, w_out, act_param,
+                                  ctx);
+    } else {
+      conv_depthwise_3x3s1p0_bias_s(dout, din, weights, bias, flag_bias, num,
+                                    ch_in, h_in, w_in, h_out, w_out,
+                                    act_param, ctx);
+    }
+  }
+  if (pad == 1) {
+    if (w_in > 4) {
+      conv_depthwise_3x3s1p1_bias(dout, din, weights, bias, flag_bias, num,
+                                  ch_in, h_in, w_in, h_out, w_out, act_param,
+                                  ctx);
+    } else {
+      conv_depthwise_3x3s1p1_bias_s(dout, din, weights, bias, flag_bias, num,
+                                    ch_in, h_in, w_in, h_out, w_out,
+                                    act_param, ctx);
+    }
+  }
+}
+
+#ifdef __aarch64__
+#define INIT_S1                                                       \
+  "PRFM PLDL1KEEP, [%[din_ptr0]] \n"                                  \
+  "PRFM PLDL1KEEP, [%[din_ptr1]] \n"                                  \
+  "PRFM PLDL1KEEP, [%[din_ptr2]] \n"                                  \
+  "PRFM PLDL1KEEP, [%[din_ptr3]] \n"                                  \
+  "PRFM PLDL1KEEP, [%[din_ptr4]] \n"                                  \
+  "PRFM PLDL1KEEP, [%[din_ptr5]] \n"                                  \
+  "movi v21.4s, #0x0\n" /* out0 = 0 */                                \
+                                                                      \
+  "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/        \
+  "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr1)*/        \
+  "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr2)*/        \
+  "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr3)*/        \
+                                                                      \
+  "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0 + 4)*/         \
+  "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr1 + 4)*/         \
+  "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr2 + 4)*/         \
+  "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr3 + 4)*/         \
+                                                                      \
+  "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/          \
+  "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/          \
+  "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/          \
+  "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
+
+#define LEFT_COMPUTE_S1                                               \
+  "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/        \
+  "ext
v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ /* r0 */ \ + "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * w0[1]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ \ + "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ \ + \ + "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * w0[0]*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ \ + "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ \ + \ + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * w0[2]*/ \ + \ + "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ /* r1 */ \ + "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * w1[1]*/ \ + "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ \ + "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ \ + \ + "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16=1234 */ \ + "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ \ + \ + /* r2 */ \ + "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ + \ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + 
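+  /* Left-edge trick used throughout LEFT_COMPUTE_S1: ext with vzero      */ \
+  /* ("ext v16, %[vzero], vN, #12") shifts one zero into lane 0, giving   */ \
+  /* the {0, d0, d1, d2} window, so pad == 1 needs no out-of-bounds load; */ \
+  /* "ext v17, vN, vN+1, #4" builds the shifted {d1, d2, d3, d4} window.  */ \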
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ /* r4 */ \ + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ + +#define LEFT_RESULT_S1 \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ /* r5 */ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "blt 3f \n" + +#define MID_COMPUTE_S1 \ + "1: \n" /* r0 */ \ + "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* 
outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + +#define MID_RESULT_S1 \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define RIGHT_COMPUTE_S1 \ + "3: \n" \ + "movi v20.4s, #0 \n" \ + "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ + "ld1 {v22.4s}, [%[doutr0]] \n" \ + "ld1 {v23.4s}, [%[doutr1]] \n" \ + "ld1 {v24.4s}, [%[doutr2]] \n" \ + "ld1 {v25.4s}, [%[doutr3]] \n" \ + \ + "bif v0.16b, v20.16b, v18.16b \n" \ + "bif v1.16b, v20.16b, v19.16b \n" 
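+  /* RIGHT_COMPUTE_S1 handles the width remainder: bif against v20       */ \
+  /* (zeros) under the vmask vectors v18/v19 keeps input lanes inside    */ \
+  /* the row and zeroes lanes past its end, while v22-v25 preload the    */ \
+  /* current output so RIGHT_RESULT_S1 can bif the valid result lanes    */ \
+  /* over them on store.                                                 */ \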
\ + "bif v2.16b, v20.16b, v18.16b \n" \ + "bif v3.16b, v20.16b, v19.16b \n" \ + \ + "bif v4.16b, v20.16b, v18.16b \n" \ + "bif v5.16b, v20.16b, v19.16b \n" \ + "bif v6.16b, v20.16b, v18.16b \n" \ + "bif v7.16b, v20.16b, v19.16b \n" \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ + "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v8.16b, v20.16b, v18.16b \n" \ + "bif v9.16b, v20.16b, v19.16b \n" \ + "bif v10.16b, v20.16b, v18.16b \n" \ + "bif v11.16b, v20.16b, v19.16b \n" \ + \ + "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v18.4s}, [%[rmask]] \n" \ + \ + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ + +#define RIGHT_RESULT_S1 \ + "bif v12.16b, v22.16b, v18.16b \n" \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 
{v13.4s}, [%[doutr1]], #16 \n" \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define LEFT_RESULT_S1_RELU \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + +#define LEFT_RESULT_S1_RELU6 \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * 
w1[1]*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + +#define LEFT_RESULT_S1_LEAKY_RELU \ + "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ + "bif v13.16b, v21.16b, v19.16b \n" /* choose*/ \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + "fcmge v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "fcmge v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "cmp %w[cnt], #1 \n" \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + +#define MID_RESULT_S1_RELU \ + "movi v20.4s, #0 \n" \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, 
%[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define MID_RESULT_S1_RELU6 \ + "movi v20.4s, #0 \n" \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla 
v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define MID_RESULT_S1_LEAKY_RELU \ + "movi v21.4s, #0 \n" \ + "fcmge v18.4s, v12.4s, v21.4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fcmge v18.4s, v13.4s, v21.4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v13.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += 
din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "fcmge v18.4s, v14.4s, v21.4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fcmge v18.4s, v15.4s, v21.4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define RIGHT_RESULT_S1_RELU \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v22.16b, v18.16b \n" \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define RIGHT_RESULT_S1_RELU6 \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * 
w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "bif v12.16b, v22.16b, v18.16b \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define RIGHT_RESULT_S1_LEAKY_RELU \ + "movi v1.4s, #0 \n" \ + "fcmge v20.4s, v12.4s, v1.4s \n" /* vcgeq_f32 */ \ + "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v21.16b, v20.16b \n" /* choose*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "bif v12.16b, v22.16b, v18.16b \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fcmge v20.4s, v13.4s, v1.4s \n" /* vcgeq_f32 */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v13.16b, v21.16b, v20.16b \n" \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fcmge v20.4s, v14.4s, v1.4s \n" /* vcgeq_f32 */ \ + "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + \ + 
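// The *_RELU6 sequences above always clamp in two steps: an fmax against
// zero followed by an fmin against the six constant (the vsix operand on
// aarch64, a vector loaded from six_ptr on armv7). A minimal intrinsics
// sketch of that clamp (an illustrative helper, not part of this patch):
#include <arm_neon.h>

static inline float32x4_t relu6_f32x4(float32x4_t x, float32x4_t vsix) {
  // fmax v, v, vzero -> vmaxq_f32; fmin v, v, vsix -> vminq_f32
  return vminq_f32(vmaxq_f32(x, vdupq_n_f32(0.f)), vsix);
}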
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v21.16b, v20.16b \n" \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "fcmge v20.4s, v15.4s, v1.4s \n" /* vcgeq_f32 */ \ + "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + "bif v15.16b, v21.16b, v20.16b \n" \ + "bif v15.16b, v25.16b, v18.16b \n" \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define COMPUTE_S_S1 \ + "prfm pldl1keep, [%[din0]]\n" \ + "prfm pldl1keep, [%[din1]]\n" \ + "prfm pldl1keep, [%[din2]]\n" \ + "prfm pldl1keep, [%[din3]]\n" \ + \ + "ld1 {v0.4s}, [%[din0]], #16\n" \ + "ld1 {v1.4s}, [%[din1]], #16\n" \ + "ld1 {v2.4s}, [%[din2]], #16\n" \ + "ld1 {v3.4s}, [%[din3]], #16\n" \ + \ + "bif v0.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v1.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v2.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v3.16b, %[vzero].16b, %[mask].16b\n" \ + \ + "ext v4.16b, %[vzero].16b, v0.16b, #12\n" \ + "ext v5.16b, %[vzero].16b, v1.16b, #12\n" \ + "ext v6.16b, %[vzero].16b, v2.16b, #12\n" \ + "ext v7.16b, %[vzero].16b, v3.16b, #12\n" \ + \ + "ext v8.16b, v0.16b, %[vzero].16b, #4\n" \ + "ext v9.16b, v1.16b, %[vzero].16b, #4\n" \ + "ext v10.16b, v2.16b, %[vzero].16b, #4\n" \ + "ext v11.16b, v3.16b, %[vzero].16b, #4\n" \ + \ + "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ + "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ + \ + "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ + "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ + \ + "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ + "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ + \ + "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ + "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ + \ + "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ + "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ + \ + "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ + "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ + \ + "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ + \ + "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ + "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ + \ + "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ + "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ + \ + "fadd v12.4s, v12.4s, v14.4s\n" \ + "fadd v12.4s, v12.4s, v16.4s\n" \ + \ + "fadd v13.4s, v13.4s, v15.4s\n" \ + "fadd v13.4s, v13.4s, v17.4s\n" \ + \ + "fadd v12.4s, v12.4s, %[bias].4s\n" \ + "fadd v13.4s, v13.4s, %[bias].4s\n" + +#define RESULT_S_S1 \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" + +#define RESULT_S_S1_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s\n" \ + "fmax v13.4s, v13.4s, %[vzero].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" + +#define RESULT_S_S1_RELU6 \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s\n" \ + "fmax v13.4s, v13.4s, %[vzero].4s\n" \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s\n" \ + "fmin v13.4s, v13.4s, %[vsix].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" + +#define RESULT_S_S1_LEAKY_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" \ + "bif v13.16b, v21.16b, v19.16b \n" \ + "st1 {v12.4s}, [%[out1]]\n" 
\ + "st1 {v13.4s}, [%[out2]]\n" +#define COMPUTE_S_S1_P0 \ + "prfm pldl1keep, [%[din0]]\n" \ + "prfm pldl1keep, [%[din1]]\n" \ + "prfm pldl1keep, [%[din2]]\n" \ + "prfm pldl1keep, [%[din3]]\n" \ + \ + "ld1 {v0.4s, v1.4s}, [%[din0]]\n" \ + "ld1 {v2.4s, v3.4s}, [%[din1]]\n" \ + "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ + "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ + \ + "bif v0.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v1.16b, %[vzero].16b, %[mask2].16b\n" \ + \ + "bif v2.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v3.16b, %[vzero].16b, %[mask2].16b\n" \ + \ + "bif v4.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v5.16b, %[vzero].16b, %[mask2].16b\n" \ + \ + "bif v6.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v7.16b, %[vzero].16b, %[mask2].16b\n" \ + \ + "ext v8.16b, v0.16b, v1.16b, #4\n" \ + "ext v9.16b, v0.16b, v1.16b, #8\n" \ + \ + "and v12.16b, %[vbias].16b, %[vbias].16b \n" \ + "and v13.16b, %[vbias].16b, %[vbias].16b \n" /* r0 */ \ + "fmul v10.4s, v0.4s, %[wr0].s[0]\n" \ + "fmul v11.4s, v8.4s, %[wr0].s[1]\n" \ + "fmla v12.4s, v9.4s, %[wr0].s[2]\n" \ + \ + "ext v8.16b, v2.16b, v3.16b, #4\n" \ + "ext v9.16b, v2.16b, v3.16b, #8\n" /* r1 */ \ + "fmul v14.4s, v2.4s, %[wr0].s[0]\n" \ + "fmla v10.4s, v2.4s, %[wr1].s[0]\n" \ + \ + "fmul v15.4s, v8.4s, %[wr0].s[1]\n" \ + "fmla v11.4s, v8.4s, %[wr1].s[1]\n" \ + \ + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ + "fmla v12.4s, v9.4s, %[wr1].s[2]\n" \ + \ + "ext v8.16b, v4.16b, v5.16b, #4\n" \ + "ext v9.16b, v4.16b, v5.16b, #8\n" /* r2 */ \ + "fmla v14.4s, v4.4s, %[wr1].s[0]\n" \ + "fmla v10.4s, v4.4s, %[wr2].s[0]\n" \ + \ + "fmla v15.4s, v8.4s, %[wr1].s[1]\n" \ + "fmla v11.4s, v8.4s, %[wr2].s[1]\n" \ + \ + "fmla v13.4s, v9.4s, %[wr1].s[2]\n" \ + "fmla v12.4s, v9.4s, %[wr2].s[2]\n" \ + \ + "ext v8.16b, v6.16b, v7.16b, #4\n" \ + "ext v9.16b, v6.16b, v7.16b, #8\n" \ + \ + "fmla v14.4s, v6.4s, %[wr2].s[0]\n" \ + \ + "fmla v15.4s, v8.4s, %[wr2].s[1]\n" \ + \ + "fadd v12.4s, v12.4s, v10.4s\n" \ + \ + "fmla v13.4s, v9.4s, %[wr2].s[2]\n" \ + \ + "fadd v12.4s, v12.4s, v11.4s\n" \ + "fadd v13.4s, v13.4s, v14.4s\n" \ + "fadd v13.4s, v13.4s, v15.4s\n" // \ + // "prfm pldl1keep, [%[out1]]\n" \ + // "prfm pldl1keep, [%[out2]]\n" \ + // \ + // "st1 {v12.4s}, [%[out1]]\n" \ + // "st1 {v13.4s}, [%[out2]]\n" \ + +#else +#define INIT_S1 \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + "pld [%[din3_ptr]] @ preload data\n" \ + \ + "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" \ + "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" \ + "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" \ + "vld1.32 {d28-d30}, [%[din3_ptr]]! 
@ load din r3\n" \ + \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vdup.32 q5, %[bias_val] @ and \n" + +#define LEFT_COMPUTE_S1 \ + "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" \ + "vext.32 q7, q8, q9, #1 @ 1234\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" \ + "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" \ + "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" \ + "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + "pld [%[din3_ptr]] @ preload data\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" \ + "vext.32 q7, q10, q11, #1 @ 1234\n" \ + \ + /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" \ + "vext.32 q7, q12, q13, #1 @ 1234\n" \ + \ + /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" \ + "vext.32 q7, q14, q15, #1 @ 1234\n" + +#define LEFT_RESULT_S1 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define MID_COMPUTE_S1 \ + "1: @ right pad entry\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + "pld [%[din3_ptr]] @ preload data\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d16-d17}, [%[din0_ptr]]! 
@ load din r0\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ + \ + "vext.32 q6, q10, q11, #1 @ 1234\n" \ + "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q12, q13, #1 @ 1234\n" \ + "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q14, q15, #1 @ 1234\n" \ + "vext.32 q7, q14, q15, #2 @ 2345\n" + +#define MID_RESULT_S1 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define RIGHT_COMPUTE_S1 \ + "3: @ right pad entry\n" \ + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ + \ + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ + "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" \ + \ + "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q10, q11, #1 @ 1234\n" \ + "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" \ + "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" \ + "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" \ + \ + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q12, q13, #1 @ 1234\n" \ + "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q14, q15, #1 @ 1234\n" \ + "vext.32 q7, q14, q15, #2 @ 2345\n" + +#define RIGHT_RESULT_S1 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + +#define LEFT_RESULT_S1_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define LEFT_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.f32 {d28-d29}, [%[six_ptr]] @ load six \n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define LEFT_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + "vld1.f32 {d28-d29}, [%[scale_ptr]] @ load scale \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vbif q4, q6, q15 @ choose \n" \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vbif q5, q6, q7 @ choose \n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define MID_RESULT_S1_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define MID_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define MID_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[scale_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vbif q4, q6, q15 @ choose \n" \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + \ + "vbif q5, q6, q7 @ choose \n" \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define RIGHT_RESULT_S1_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + +#define RIGHT_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ + \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + +#define RIGHT_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[scale_ptr]] @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vbif q4, q6, q15 @ choose \n" \ + \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + "vbif q5, q6, q7 @ choose \n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + +#define COMPUTE_S_S1 \ + "pld [%[din0]]\n" \ + "pld [%[din1]]\n" \ + "pld [%[din2]]\n" \ + "pld [%[din3]]\n" \ + \ + "vld1.32 {d12-d13}, [%[din0]]!\n" \ + "vld1.32 {d14-d15}, [%[din1]]!\n" \ + "vld1.32 {d16-d17}, [%[din2]]!\n" \ + "vld1.32 {d18-d19}, [%[din3]]!\n" \ + \ + "vbif q6, %q[vzero], %q[mask]\n" \ + "vbif q7, %q[vzero], %q[mask]\n" \ + "vbif q8, %q[vzero], %q[mask]\n" \ + "vbif q9, %q[vzero], %q[mask]\n" \ + \ + "vmul.f32 q14, q6, %e[wr0][1]\n" \ + "vmul.f32 q15, q7, %e[wr0][1]\n" \ + \ + "vmla.f32 q14, q7, %e[wr1][1]\n" \ + "vmla.f32 q15, q8, %e[wr1][1]\n" \ + \ + "vmla.f32 q14, q8, %e[wr2][1]\n" \ + "vmla.f32 q15, q9, %e[wr2][1]\n" \ + \ + "vext.32 q10, %q[vzero], q6, #3\n" \ + "vext.32 q11, %q[vzero], q7, #3\n" \ + "vext.32 q12, %q[vzero], q8, #3\n" \ + "vext.32 q13, %q[vzero], q9, #3\n" \ + \ + "vmla.f32 q14, q10, %e[wr0][0]\n" \ + "vmla.f32 q15, q11, %e[wr0][0]\n" \ + \ + "vmla.f32 q14, q11, %e[wr1][0]\n" \ + "vmla.f32 q15, q12, %e[wr1][0]\n" \ + \ + "vmla.f32 q14, q12, %e[wr2][0]\n" \ + "vmla.f32 q15, q13, %e[wr2][0]\n" \ + \ + "vext.32 q10, q6, %q[vzero], #1\n" \ + "vext.32 q11, q7, %q[vzero], #1\n" \ + "vext.32 q12, q8, %q[vzero], #1\n" \ + "vext.32 q13, q9, %q[vzero], #1\n" \ + \ + "vmla.f32 q14, q10, %f[wr0][0]\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" \ + \ + "vmla.f32 q14, q11, %f[wr1][0]\n" \ + "vmla.f32 q15, q12, %f[wr1][0]\n" \ + \ + "vmla.f32 q14, q12, %f[wr2][0]\n" \ + "vmla.f32 q15, q13, %f[wr2][0]\n" \ + \ + "vadd.f32 q14, q14, %q[bias]\n" \ + "vadd.f32 q15, q15, %q[bias]\n" + +#define RESULT_S_S1 \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define RESULT_S_S1_RELU \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define RESULT_S_S1_RELU6 \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vld1.32 {d20-d21}, [%[six_ptr]] \n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + \ + "vmin.f32 q14, q14, q10 \n" \ + "vmin.f32 q15, q15, q10 \n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define RESULT_S_S1_LEAKY_RELU \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vld1.32 {d18-d19}, [%[scale_ptr]] 
\n" \ + "vcge.f32 q10, q14, %q[vzero] @ q0 > 0 \n" \ + "vcge.f32 q11, q15, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q12, q14, q9 \n" \ + "vmul.f32 q13, q15, q9 \n" \ + \ + "vbif q14, q12, q10 \n" \ + "vbif q15, q13, q11 \n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define COMPUTE_S_S1_P0 \ + "pld [%[din0]]\n" \ + "pld [%[din1]]\n" \ + "pld [%[din2]]\n" \ + "pld [%[din3]]\n" \ + "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" \ + "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" \ + "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" \ + "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" \ + \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ + \ + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ + \ + "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ + \ + "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ + \ + "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ + "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ + "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ + "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ + \ + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vext.32 q6, q10, q11, #1 @ 1234\n" \ + "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ + "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ + "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q12, q13, #1 @ 1234\n" \ + "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ + "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ + \ + "vext.32 q6, q14, q15, #1 @ 1234\n" \ + "vext.32 q7, q14, q15, #2 @ 2345\n" /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + "vadd.f32 q4, q4, q10 @ q4 += q10 \n" \ + \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vadd.f32 q14, q4, q11 @ q4 += q10 \n" \ + \ + "vadd.f32 q5, q5, q8 @ q4 += q10 \n" \ + "vadd.f32 q15, q5, q9 @ q4 += q10 \n" + +#endif + +#ifdef __aarch64__ +void act_switch_3x3s1p1(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + const float *din_ptr4, + const float 
*din_ptr5, + float *doutr0, + float *doutr1, + float *doutr2, + float *doutr3, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask, + unsigned int *rmask, + float32x4_t vzero, + float *vbias, + int cnt, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vsix] "w"(vsix), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+        asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
+                         MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
+                         RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
+                     : [cnt] "+r"(cnt),
+                       [din_ptr0] "+r"(din_ptr0),
+                       [din_ptr1] "+r"(din_ptr1),
+                       [din_ptr2] "+r"(din_ptr2),
+                       [din_ptr3] "+r"(din_ptr3),
+                       [din_ptr4] "+r"(din_ptr4),
+                       [din_ptr5] "+r"(din_ptr5),
+                       [doutr0] "+r"(doutr0),
+                       [doutr1] "+r"(doutr1),
+                       [doutr2] "+r"(doutr2),
+                       [doutr3] "+r"(doutr3)
+                     : [w0] "w"(wr0),
+                       [w1] "w"(wr1),
+                       [w2] "w"(wr2),
+                       [vscale] "w"(vscale),
+                       [bias_val] "r"(vbias),
+                       [vmask] "r"(vmask),
+                       [rmask] "r"(rmask),
+                       [vzero] "w"(vzero)
+                     : "cc",
+                       "memory",
+                       "v0",
+                       "v1",
+                       "v2",
+                       "v3",
+                       "v4",
+                       "v5",
+                       "v6",
+                       "v7",
+                       "v8",
+                       "v9",
+                       "v10",
+                       "v11",
+                       "v12",
+                       "v13",
+                       "v14",
+                       "v15",
+                       "v16",
+                       "v17",
+                       "v18",
+                       "v19",
+                       "v20",
+                       "v21",
+                       "v22",
+                       "v23",
+                       "v24",
+                       "v25");
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param.active_type)
+                   << " fuse not support";
+    }
+  } else {
+    asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
+                     MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
+                 : [cnt] "+r"(cnt),
+                   [din_ptr0] "+r"(din_ptr0),
+                   [din_ptr1] "+r"(din_ptr1),
+                   [din_ptr2] "+r"(din_ptr2),
+                   [din_ptr3] "+r"(din_ptr3),
+                   [din_ptr4] "+r"(din_ptr4),
+                   [din_ptr5] "+r"(din_ptr5),
+                   [doutr0] "+r"(doutr0),
+                   [doutr1] "+r"(doutr1),
+                   [doutr2] "+r"(doutr2),
+                   [doutr3] "+r"(doutr3)
+                 : [w0] "w"(wr0),
+                   [w1] "w"(wr1),
+                   [w2] "w"(wr2),
+                   [bias_val] "r"(vbias),
+                   [vmask] "r"(vmask),
+                   [rmask] "r"(rmask),
+                   [vzero] "w"(vzero)
+                 : "cc",
+                   "memory",
+                   "v0",
+                   "v1",
+                   "v2",
+                   "v3",
+                   "v4",
+                   "v5",
+                   "v6",
+                   "v7",
+                   "v8",
+                   "v9",
+                   "v10",
+                   "v11",
+                   "v12",
+                   "v13",
+                   "v14",
+                   "v15",
+                   "v16",
+                   "v17",
+                   "v18",
+                   "v19",
+                   "v20",
+                   "v21",
+                   "v22",
+                   "v23",
+                   "v24",
+                   "v25");
+  }
+}
+#else
+void act_switch_3x3s1p1(const float *din_ptr0,
+                        const float *din_ptr1,
+                        const float *din_ptr2,
+                        const float *din_ptr3,
+                        float *doutr0,
+                        float *doutr1,
+                        float32x4_t wr0,
+                        float32x4_t wr1,
+                        float32x4_t wr2,
+                        unsigned int *vmask_ptr,
+                        unsigned int *rmask_ptr,
+                        float32x4_t vzero,
+                        float bias_val,
+                        int cnt,
+                        const operators::ActivationParam act_param) {
+  bool has_active = act_param.has_active;
+  if (has_active) {
+    float tmp = act_param.Relu_clipped_coef;
+    float ss = act_param.Leaky_relu_alpha;
+    float vsix[4] = {tmp, tmp, tmp, tmp};
+    float vscale[4] = {ss, ss, ss, ss};
+
+    switch (act_param.active_type) {
+      case lite_api::ActivationType::kRelu:
+        asm volatile(
+            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
+                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
+            : [dout_ptr1] "+r"(doutr0),
+              [dout_ptr2] "+r"(doutr1),
+              [din0_ptr] "+r"(din_ptr0),
+              [din1_ptr] "+r"(din_ptr1),
+              [din2_ptr] "+r"(din_ptr2),
+              [din3_ptr] "+r"(din_ptr3),
+              [cnt] "+r"(cnt),
+              [rmask] "+r"(rmask_ptr),
+              [vmask] "+r"(vmask_ptr)
+            : [wr0] "w"(wr0),
+              [wr1] "w"(wr1),
+              [wr2] "w"(wr2),
+              [bias_val] "r"(bias_val),
+              [vzero] "w"(vzero)
+            : "cc",
+              "memory",
+              "q4",
+              "q5",
+              "q6",
+              "q7",
+              "q8",
+              "q9",
+              "q10",
+              "q11",
+              "q12",
+              "q13",
+              "q14",
+              "q15");
+        break;
+      case lite_api::ActivationType::kRelu6:
+        /* 0 <= din <= 6 */
+        asm volatile(
+            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
+                MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
+            : [dout_ptr1] "+r"(doutr0),
+              [dout_ptr2] "+r"(doutr1),
+              [din0_ptr] "+r"(din_ptr0),
+              [din1_ptr] "+r"(din_ptr1),
+              [din2_ptr] "+r"(din_ptr2),
+              [din3_ptr] "+r"(din_ptr3),
+              [cnt] "+r"(cnt),
+              [rmask] "+r"(rmask_ptr),
+              [vmask] "+r"(vmask_ptr)
+            : [wr0] "w"(wr0),
+              [wr1] "w"(wr1),
+              [wr2] "w"(wr2),
+              [bias_val] "r"(bias_val),
+              [six_ptr] "r"(vsix),
+              [vzero] "w"(vzero)
+            : "cc",
+              "memory",
+              "q4",
+              "q5",
+              "q6",
+              "q7",
+              "q8",
+              "q9",
+              "q10",
+              "q11",
+              "q12",
+              "q13",
+              "q14",
+              "q15");
+        break;
+      case lite_api::ActivationType::kLeakyRelu:
+        /*din = din >= 0 ? din : din * scale*/
+        asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
+                         MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
+                         RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
+                     : [dout_ptr1] "+r"(doutr0),
+                       [dout_ptr2] "+r"(doutr1),
+                       [din0_ptr] "+r"(din_ptr0),
+                       [din1_ptr] "+r"(din_ptr1),
+                       [din2_ptr] "+r"(din_ptr2),
+                       [din3_ptr] "+r"(din_ptr3),
+                       [cnt] "+r"(cnt),
+                       [rmask] "+r"(rmask_ptr),
+                       [vmask] "+r"(vmask_ptr)
+                     : [wr0] "w"(wr0),
+                       [wr1] "w"(wr1),
+                       [wr2] "w"(wr2),
+                       [bias_val] "r"(bias_val),
+                       [scale_ptr] "r"(vscale),
+                       [vzero] "w"(vzero)
+                     : "cc",
+                       "memory",
+                       "q4",
+                       "q5",
+                       "q6",
+                       "q7",
+                       "q8",
+                       "q9",
+                       "q10",
+                       "q11",
+                       "q12",
+                       "q13",
+                       "q14",
+                       "q15");
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param.active_type)
+                   << " fuse not support";
+    }
+  } else {
+    asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
+                     MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
+                 : [dout_ptr1] "+r"(doutr0),
+                   [dout_ptr2] "+r"(doutr1),
+                   [din0_ptr] "+r"(din_ptr0),
+                   [din1_ptr] "+r"(din_ptr1),
+                   [din2_ptr] "+r"(din_ptr2),
+                   [din3_ptr] "+r"(din_ptr3),
+                   [cnt] "+r"(cnt),
+                   [rmask] "+r"(rmask_ptr),
+                   [vmask] "+r"(vmask_ptr)
+                 : [wr0] "w"(wr0),
+                   [wr1] "w"(wr1),
+                   [wr2] "w"(wr2),
+                   [bias_val] "r"(bias_val),
+                   [vzero] "w"(vzero)
+                 : "cc",
+                   "memory",
+                   "q4",
+                   "q5",
+                   "q6",
+                   "q7",
+                   "q8",
+                   "q9",
+                   "q10",
+                   "q11",
+                   "q12",
+                   "q13",
+                   "q14",
+                   "q15");
+  }
+}
+#endif
+/**
+ * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
+ * width > 4
+ */
+void conv_depthwise_3x3s1p1_bias(float *dout,
+                                 const float *din,
+                                 const float *weights,
+                                 const float *bias,
+                                 bool flag_bias,
+                                 const int num,
+                                 const int ch_in,
+                                 const int h_in,
+                                 const int w_in,
+                                 const int h_out,
+                                 const int w_out,
+                                 const operators::ActivationParam act_param,
+                                 ARMContext *ctx) {
+  //! pad is done implicitly
+  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
+  //! for 4x6 convolution window
+  const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
+
+  float *zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  float *write_ptr = zero_ptr + w_in;
+
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_w = w_out >> 2;
+  int remain = w_out % 4;
+  int cnt_col = tile_w - 1;
+
+  unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
+  const unsigned int remain_idx[4] = {0, 1, 2, 3};
+
+  if (remain == 0 && size_pad_right == 5) {
+    size_pad_right = 1;
+    cnt_col -= 1;
+    remain = 4;
+  } else if (remain == 0 && size_pad_right == 6) {
+    size_pad_right = 2;
+    cnt_col -= 1;
+    remain = 4;
+  }
+
+  uint32x4_t vmask_rp1 =
+      vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
+  uint32x4_t vmask_rp2 =
+      vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
+  uint32x4_t vmask_result =
+      vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remain_idx));
+
+  unsigned int vmask[8];
+  vst1q_u32(vmask, vmask_rp1);
+  vst1q_u32(vmask + 4, vmask_rp2);
+
+  unsigned int rmask[4];
+  vst1q_u32(rmask, vmask_result);
+
+  float32x4_t vzero = vdupq_n_f32(0.f);
+
+  for (int n = 0; n < num; ++n) {
+    const float *din_batch = din + n * ch_in * size_in_channel;
+    float *dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      float *dout_ptr = dout_batch + c * size_out_channel;
+
+      const float *din_ch_ptr = din_batch + c * size_in_channel;
+
+      float bias_val = flag_bias ? bias[c] : 0.f;
+      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+
+      const float *wei_ptr = weights + c * w_stride;
+
+      float32x4_t wr0 = vld1q_f32(wei_ptr);
+      float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
+      float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
+
+      float *doutr0 = dout_ptr;
+      float *doutr1 = doutr0 + w_out;
+      float *doutr2 = doutr1 + w_out;
+      float *doutr3 = doutr2 + w_out;
+
+      const float *dr0 = din_ch_ptr;
+      const float *dr1 = dr0 + w_in;
+      const float *dr2 = dr1 + w_in;
+      const float *dr3 = dr2 + w_in;
+      const float *dr4 = dr3 + w_in;
+      const float *dr5 = dr4 + w_in;
+
+      const float *din_ptr0 = dr0;
+      const float *din_ptr1 = dr1;
+      const float *din_ptr2 = dr2;
+      const float *din_ptr3 = dr3;
+      const float *din_ptr4 = dr4;
+      const float *din_ptr5 = dr5;
+      float *ptr_zero = const_cast<float *>(zero);
+#ifdef __aarch64__
+      for (int i = 0; i < h_out; i += 4) {
+        //! process top pad pad_h = 1
+        din_ptr0 = dr0;
+        din_ptr1 = dr1;
+        din_ptr2 = dr2;
+        din_ptr3 = dr3;
+        din_ptr4 = dr4;
+        din_ptr5 = dr5;
+
+        doutr0 = dout_ptr;
+        doutr1 = doutr0 + w_out;
+        doutr2 = doutr1 + w_out;
+        doutr3 = doutr2 + w_out;
+        if (i == 0) {
+          din_ptr0 = zero_ptr;
+          din_ptr1 = dr0;
+          din_ptr2 = dr1;
+          din_ptr3 = dr2;
+          din_ptr4 = dr3;
+          din_ptr5 = dr4;
+          dr0 = dr3;
+          dr1 = dr4;
+          dr2 = dr5;
+        } else {
+          dr0 = dr4;
+          dr1 = dr5;
+          dr2 = dr1 + w_in;
+        }
+        dr3 = dr2 + w_in;
+        dr4 = dr3 + w_in;
+        dr5 = dr4 + w_in;
+
+        //! process bottom pad
+        if (i + 5 > h_in) {
+          switch (i + 5 - h_in) {
+            case 5:
+              din_ptr1 = zero_ptr;
+            case 4:
+              din_ptr2 = zero_ptr;
+            case 3:
+              din_ptr3 = zero_ptr;
+            case 2:
+              din_ptr4 = zero_ptr;
+            case 1:
+              din_ptr5 = zero_ptr;
+            default:
+              break;
+          }
+        }
+        //!
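// The masks built above drive the bif/vbif right-edge handling: lane i of
// vmask stays all-ones while the lane still maps to a valid input column,
// so the load is kept there and zero is substituted elsewhere. A scalar
// sketch of the same predicate (illustrative only, not part of this patch):
#include <stdint.h>

static void build_right_pad_mask(uint32_t mask[8], uint32_t size_pad_right) {
  const uint32_t right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
  for (int i = 0; i < 8; ++i) {
    // vcgeq_u32(right_pad_idx, size_pad_right): keep lane when idx >= pad
    mask[i] = (right_pad_idx[i] >= size_pad_right) ? 0xffffffffu : 0u;
  }
}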
process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = cnt_col; + act_switch_3x3s1p1(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + doutr0, + doutr1, + doutr2, + doutr3, + wr0, + wr1, + wr2, + vmask, + rmask, + vzero, + vbias, + cnt, + act_param); + dout_ptr = dout_ptr + 4 * w_out; + } +#else + for (int i = 0; i < h_out; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = dout_ptr + w_out; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; + unsigned int *rmask_ptr = rmask; + unsigned int *vmask_ptr = vmask; + act_switch_3x3s1p1(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + doutr0, + doutr1, + wr0, + wr1, + wr2, + vmask_ptr, + rmask_ptr, + vzero, + bias_val, + cnt, + act_param); + dout_ptr += 2 * w_out; + } //! end of processing mid rows +#endif + } + } +} +void act_switch_3x3s1p1_s(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp, + float32x4_t vzero, + float32x4_t wbias, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); +#else + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] 
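// The case fall-throughs in the bottom-pad switches above are deliberate:
// din_ptr r reads input row (i - 1 + r), so every row index at or past
// h_in must be redirected to the zero buffer. A loop-form sketch of the
// aarch64 path (illustrative; rows[] stands for din_ptr0..din_ptr5):
static void pad_bottom_rows(const float *rows[6],
                            int i,
                            int h_in,
                            const float *zero_ptr) {
  for (int r = 1; r < 6; ++r) {
    if (i - 1 + r >= h_in) rows[r] = zero_ptr;  // rows below the image
  }
}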
"w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [vsix] "w"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [six_ptr] "r"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? din : din * scale*/ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [vscale] "w"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [scale_ptr] "r"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p1_bias_s(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + const operators::ActivationParam act_param, + ARMContext *ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
for 4x6 convolution window + const int right_pad_idx[4] = {3, 2, 1, 0}; + const float zero[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float *dout_channel = dout_batch + i * size_out_channel; + const float *din_channel = din_batch + i * size_in_channel; + const float *weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + float *doutr0 = dout_channel; + float *doutr1 = dout_channel + w_out; + + const float *dr0 = din_channel; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + + for (int j = 0; j < h_out; j += 2) { + const float *dr0_ptr = dr0; + const float *dr1_ptr = dr1; + const float *dr2_ptr = dr2; + const float *dr3_ptr = dr3; + if (j == 0) { + dr0_ptr = zero; + dr1_ptr = dr0; + dr2_ptr = dr1; + dr3_ptr = dr2; + dr0 = dr1; + dr1 = dr2; + } else { + dr0 = dr2; + dr1 = dr3; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + //! process bottom pad + if (j + 3 > h_in) { + switch (j + 3 - h_in) { + case 3: + dr1_ptr = zero; + case 2: + dr2_ptr = zero; + case 1: + dr3_ptr = zero; + default: + break; + } + } + //! process bottom remain + if (j + 2 > h_out) { + doutr1 = trash_buf; + } + act_switch_3x3s1p1_s(dr0_ptr, + dr1_ptr, + dr2_ptr, + dr3_ptr, + out_buf1, + out_buf2, + wr0, + wr1, + wr2, + vmask_rp, + vzero, + wbias, + act_param); + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + doutr0 = doutr1; + doutr1 += w_out; + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} + +#ifdef __aarch64__ +void act_switch_3x3s1p0(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + const float *din_ptr4, + const float *din_ptr5, + float *doutr0, + float *doutr1, + float *doutr2, + float *doutr3, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask, + unsigned int *rmask, + float32x4_t vzero, + float *vbias, + int cnt, + int remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] 
"+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU6 + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU6 "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vsix] "w"(vsix), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_LEAKY_RELU "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vscale] "w"(vscale), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1 + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } +} +#else +void act_switch_3x3s1p0(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask_ptr, + unsigned int *rmask_ptr, + float32x4_t vzero, + float bias_val, + int cnt, + int remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 
@ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU6 "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [six_ptr] "r"(vsix), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_LEAKY_RELU + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_LEAKY_RELU + "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [scale_ptr] "r"(vscale), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile( + INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1 + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +} +#endif +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ +void conv_depthwise_3x3s1p0_bias(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + const operators::ActivationParam act_param, + ARMContext *ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! 
for 4x6 convolution window + const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + + float *zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float *write_ptr = zero_ptr + w_in; + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = w_out >> 2; + int remain = w_out % 4; + + unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); + const int remian_idx[4] = {0, 1, 2, 3}; + + if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0 + tile_w -= 1; + remain = 4; + size_pad_right = 2; + } + + uint32x4_t vmask_rp1 = + vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_rp2 = + vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_result = + vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + unsigned int rmask[4]; + vst1q_u32(rmask, vmask_result); + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + float *dout_ptr = dout_batch + c * size_out_channel; + + const float *din_ch_ptr = din_batch + c * size_in_channel; + + float bias_val = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + const float *wei_ptr = weights + c * w_stride; + + float32x4_t wr0 = vld1q_f32(wei_ptr); + float32x4_t wr1 = vld1q_f32(wei_ptr + 3); + float32x4_t wr2 = vld1q_f32(wei_ptr + 6); + + float *doutr0 = dout_ptr; + float *doutr1 = doutr0 + w_out; + float *doutr2 = doutr1 + w_out; + float *doutr3 = doutr2 + w_out; + + const float *dr0 = din_ch_ptr; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + const float *dr4 = dr3 + w_in; + const float *dr5 = dr4 + w_in; + + const float *din_ptr0 = dr0; + const float *din_ptr1 = dr1; + const float *din_ptr2 = dr2; + const float *din_ptr3 = dr3; + const float *din_ptr4 = dr4; + const float *din_ptr5 = dr5; + + float *ptr_zero = const_cast(zero); +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 4) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + doutr2 = doutr1 + w_out; + doutr3 = doutr2 + w_out; + + dr0 = dr4; + dr1 = dr5; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + dr5 = dr4 + w_in; + + //! process bottom pad + if (i + 5 >= h_in) { + switch (i + 5 - h_in) { + case 4: + din_ptr1 = zero_ptr; + case 3: + din_ptr2 = zero_ptr; + case 2: + din_ptr3 = zero_ptr; + case 1: + din_ptr4 = zero_ptr; + case 0: + din_ptr5 = zero_ptr; + default: + break; + } + } + //! 
process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = tile_w; + act_switch_3x3s1p0(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + doutr0, + doutr1, + doutr2, + doutr3, + wr0, + wr1, + wr2, + vmask, + rmask, + vzero, + vbias, + cnt, + remain, + act_param); + dout_ptr = dout_ptr + 4 * w_out; + } +#else + for (int i = 0; i < h_out; i += 2) { + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = dout_ptr + w_out; + + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + //! process bottom pad + if (i + 4 > h_in) { + switch (i + 4 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = tile_w; + unsigned int *rmask_ptr = rmask; + unsigned int *vmask_ptr = vmask; + act_switch_3x3s1p0(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + doutr0, + doutr1, + wr0, + wr1, + wr2, + vmask_ptr, + rmask_ptr, + vzero, + bias_val, + cnt, + remain, + act_param); + dout_ptr += 2 * w_out; + } //! end of processing mid rows +#endif + } + } +} +void act_switch_3x3s1p0_s(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + float32x4_t vzero, + float32x4_t wbias, + unsigned int *vmask_ptr, + float bias_val, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); +#else + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + 
[mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [six_ptr] "r"(vsix), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? din : din * scale*/ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [scale_ptr] "r"(vscale), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p0_bias_s(float *dout, + const float *din, + const float *weights, + const float *bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + const operators::ActivationParam act_param, + ARMContext *ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
for 4x6 convolution window + const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp1 = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); + uint32x4_t vmask_rp2 = + vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float *din_batch = din + n * ch_in * size_in_channel; + float *dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float *dout_channel = dout_batch + i * size_out_channel; + const float *din_channel = din_batch + i * size_in_channel; + const float *weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t wbias; + float bias_val = 0.f; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + bias_val = bias[i]; + } else { + wbias = vdupq_n_f32(0.f); + } + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + float *doutr0 = dout_channel; + float *doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_out; j += 2) { + const float *dr0 = din_channel + j * w_in; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; + + doutr0 = dout_channel + j * w_out; + doutr1 = doutr0 + w_out; + + if (j + 4 > h_in) { + switch (j + 4 - h_in) { + case 3: + dr1 = zero_ptr; + case 2: + dr2 = zero_ptr; + case 1: + dr3 = zero_ptr; + default: + break; + } + } + if (j + 2 > h_out) { + doutr1 = trash_buf; + } + unsigned int *vmask_ptr = vmask; + act_switch_3x3s1p0_s(dr0, + dr1, + dr2, + dr3, + out_buf1, + out_buf2, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + vzero, + wbias, + vmask_ptr, + bias_val, + act_param); + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc new file mode 100644 index 0000000000000000000000000000000000000000..55ea94949ba93396c97be5e3ea66d6e29ce95429 --- /dev/null +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -0,0 +1,1043 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
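The new file below packs the input into 4-channel-interleaved tiles (via prepack_input_nxwc4_dw) and computes a 2-row-by-4-column output block per inline-assembly call, with the activation fused into the store path. As a ground-truth reference for checking those kernels — a minimal sketch, not part of the patch, assuming plain per-channel NCHW layout — a scalar depthwise 3x3 stride-1 convolution looks like this:

// Scalar reference: dout is [ch, h_out, w_out], din is [ch, h_in, w_in],
// weights is [ch, 3, 3]; zero padding of width `pad` on all sides.
static void depthwise_3x3s1_ref(float* dout, const float* din,
                                const float* weights, const float* bias,
                                int ch, int h_in, int w_in, int pad) {
  const int h_out = h_in + 2 * pad - 2;  // kernel 3, stride 1
  const int w_out = w_in + 2 * pad - 2;
  for (int c = 0; c < ch; ++c) {
    const float* w = weights + c * 9;
    const float* in = din + c * h_in * w_in;
    float* out = dout + c * h_out * w_out;
    for (int oh = 0; oh < h_out; ++oh) {
      for (int ow = 0; ow < w_out; ++ow) {
        float sum = bias ? bias[c] : 0.f;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            const int ih = oh - pad + kh;
            const int iw = ow - pad + kw;
            if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
              sum += w[kh * 3 + kw] * in[ih * w_in + iw];  // skip padded taps
            }
          }
        }
        out[oh * w_out + ow] = sum;
      }
    }
  }
}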
+ +#include +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +// clang-format off +#ifdef __aarch64__ +#define COMPUTE \ + "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \ + "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \ + "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ \ + "ldp q4, q5, [%[inr0]]\n" /* load input r0*/ \ + "ldp q10, q11, [%[inr1]]\n" /* load input r1*/ \ + /* r0, r1, mul w0, get out r0, r1 */ \ + "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ \ + "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ \ + "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ \ + "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ \ + "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ \ + "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ \ + "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ \ + "fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ \ + /* r0, r1, mul w1, get out r0, r1 */ \ + "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ \ + "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ \ + "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ \ + "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ \ + "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ \ + "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ \ + "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ \ + /* r0, r1, mul w2, get out r0, r1 */ \ + "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ \ + "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ \ + "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ \ + "ldp q4, q5, [%[inr2]]\n" /* load input r2*/ \ + "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ \ + "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ \ + "fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ \ + "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ \ + /* r1, r2, mul w3, get out r0, r1 */ \ + "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ \ + "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ \ + "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ \ + "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ \ + "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ \ + "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ \ + "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ \ + "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ \ + /* r1, r2, mul w4, get out r0, r1 */ \ + "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ \ + "ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ \ + "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ \ + "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ \ + "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ \ + "ldp x0, x1, [%[outl]] \n" \ + "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ \ + "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ \ + "fmla v21.4s , %[w4].4s, v3.4s\n" /* 
outr12 = w4 * r2[3]*/ \ + "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ \ + /* r1, r2, mul w5, get out r0, r1 */ \ + "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ \ + "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ \ + "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ \ + "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ \ + "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ \ + "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ \ + "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ \ + "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ \ + "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ \ + "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ \ + /* r2, r3, mul w6, get out r0, r1 */ \ + "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ \ + "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ \ + "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ \ + "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ \ + "ldp x2, x3, [%[outl], #16] \n" \ + "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ \ + "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ \ + "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ \ + "fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ \ + /* r2, r3, mul w7, get out r0, r1 */ \ + "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ \ + "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ \ + "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ \ + "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ \ + "ldp x4, x5, [%[outl], #32] \n" \ + "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ \ + "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ \ + "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ \ + "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ \ + /* r2, r3, mul w8, get out r0, r1 */ \ + "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ \ + "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ \ + "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ \ + "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ \ + "ldp x6, x7, [%[outl], #48] \n" \ + "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ \ + "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ \ + "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ \ + "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ \ + \ + "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ \ + /* transpose */ \ + "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ \ + "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ \ + "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ \ + "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ \ + "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v21.2d, v1.2d, 
v3.2d\n" /* r0: d0d1d2d3*/ \ + "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ \ + "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ \ + "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ \ + "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ + +#define RELU \ + "movi v0.4s, #0\n" /* for relu */ \ + "ldr x0, [%[outl], #80]\n" \ + "fmax v15.4s, v15.4s, v0.4s\n" \ + "fmax v16.4s, v16.4s, v0.4s\n" \ + "fmax v17.4s, v17.4s, v0.4s\n" \ + "fmax v18.4s, v18.4s, v0.4s\n" \ + "ld1 {v1.4s}, [x0]\n" \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" \ + "ldr x0, [%[outl]]\n" \ + +#define RELU6 \ + "fmin v15.4s, v15.4s, v1.4s\n" \ + "fmin v16.4s, v16.4s, v1.4s\n" \ + "fmin v17.4s, v17.4s, v1.4s\n" \ + "fmin v18.4s, v18.4s, v1.4s\n" \ + "fmin v19.4s, v19.4s, v1.4s\n" \ + "fmin v20.4s, v20.4s, v1.4s\n" \ + "fmin v21.4s, v21.4s, v1.4s\n" \ + "fmin v22.4s, v22.4s, v1.4s\n" + +#define LEAKY_RELU \ + "movi v0.4s, #0\n" /* for relu */ \ + "ldr x0, [%[outl], #88]\n" \ + "fcmge v1.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ \ + "ld1 {v9.4s}, [x0] \n" \ + "fcmge v3.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v4.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ \ + "ldr x0, [%[outl]] \n" \ + "fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \ + "fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \ + "fmul v7.4s, v17.4s, v9.4s \n" /* mul */ \ + "fmul v8.4s, v18.4s, v9.4s \n" /* mul */ \ + "bif v15.16b, v5.16b, v1.16b \n" /* choose*/ \ + "bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \ + "bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \ + "bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \ + "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v3.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v4.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \ + "fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \ + "fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \ + "fmul v8.4s, v22.4s, v9.4s \n" /* mul */ \ + "bif v19.16b, v5.16b, v1.16b \n" /* choose*/ \ + "bif v20.16b, v6.16b, v2.16b \n" /* choose*/ \ + "bif v21.16b, v7.16b, v3.16b \n" /* choose*/ \ + "bif v22.16b, v8.16b, v4.16b \n" /* choose*/ + +#define STORE \ + "cbnz %w[flag_mask], 1f\n" \ + "str q15, [x0]\n" /* save outc00 */ \ + "str q16, [x4]\n" /* save outc01 */ \ + "str q17, [x1]\n" /* save outc10 */ \ + "str q18, [x5]\n" /* save outc11 */ \ + "str q19, [x2]\n" /* save outc20 */ \ + "str q20, [x6]\n" /* save outc21 */ \ + "str q21, [x3]\n" /* save outc30 */ \ + "str q22, [x7]\n" /* save outc31 */ \ + "b 2f\n" \ + "1:\n" \ + "str q15, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q17, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q19, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q21, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q16, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q18, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q20, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q22, [%[out]], #16 \n" /* save remain to pre_out */ \ + "2:\n" +#else +#define COMPUTE \ + /* load weights */ \ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" \ + /* load r0, r1 */ \ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" \ + "vld1.32 {d4-d7}, [%[r0]]! 
@ load r0, q2, q3\n" \ + /* main loop */ \ + "0: @ main loop\n" \ + /* mul r0 with w0, w1, w2, get out r0 */ \ + "vmul.f32 q8, q5, q0 @ w0 * inr00\n" \ + "vmul.f32 q9, q5, q1 @ w0 * inr01\n" \ + "vmul.f32 q10, q5, q2 @ w0 * inr02\n" \ + "vmul.f32 q11, q5, q3 @ w0 * inr03\n" \ + "vmla.f32 q8, q6, q1 @ w1 * inr01\n" \ + "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" \ + "vmla.f32 q9, q6, q2 @ w1 * inr02\n" \ + "vmla.f32 q10, q6, q3 @ w1 * inr03\n" \ + "vmla.f32 q11, q6, q0 @ w1 * inr04\n" \ + "vmla.f32 q8, q7, q2 @ w2 * inr02\n" \ + "vmla.f32 q9, q7, q3 @ w2 * inr03\n" \ + "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" \ + "vmla.f32 q10, q7, q0 @ w2 * inr04\n" \ + "vmla.f32 q11, q7, q1 @ w2 * inr05\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" \ + "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" \ + /* mul r1 with w0-w5, get out r0, r1 */ \ + "vmul.f32 q12, q5, q2 @ w0 * inr10\n" \ + "vmul.f32 q13, q5, q3 @ w0 * inr11\n" \ + "vmul.f32 q14, q5, q0 @ w0 * inr12\n" \ + "vmul.f32 q15, q5, q1 @ w0 * inr13\n" \ + "vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n" \ + "vmla.f32 q8, q4, q2 @ w3 * inr10\n" \ + "vmla.f32 q9, q4, q3 @ w3 * inr11\n" \ + "vmla.f32 q10, q4, q0 @ w3 * inr12\n" \ + "vmla.f32 q11, q4, q1 @ w3 * inr13\n" \ + /* mul r1 with w1, w4, get out r1, r0 */ \ + "vmla.f32 q8, q5, q3 @ w4 * inr11\n" \ + "vmla.f32 q12, q6, q3 @ w1 * inr11\n" \ + "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" \ + "vmla.f32 q9, q5, q0 @ w4 * inr12\n" \ + "vmla.f32 q13, q6, q0 @ w1 * inr12\n" \ + "vmla.f32 q10, q5, q1 @ w4 * inr13\n" \ + "vmla.f32 q14, q6, q1 @ w1 * inr13\n" \ + "vmla.f32 q11, q5, q2 @ w4 * inr14\n" \ + "vmla.f32 q15, q6, q2 @ w1 * inr14\n" \ + "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" \ + /* mul r1 with w2, w5, get out r1, r0 */ \ + "vmla.f32 q12, q7, q0 @ w2 * inr12\n" \ + "vmla.f32 q13, q7, q1 @ w2 * inr13\n" \ + "vmla.f32 q8, q6, q0 @ w5 * inr12\n" \ + "vmla.f32 q9, q6, q1 @ w5 * inr13\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" \ + "vmla.f32 q14, q7, q2 @ w2 * inr14\n" \ + "vmla.f32 q15, q7, q3 @ w2 * inr15\n" \ + "vmla.f32 q10, q6, q2 @ w5 * inr14\n" \ + "vmla.f32 q11, q6, q3 @ w5 * inr15\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" \ + /* mul r2 with w3-w8, get out r0, r1 */ \ + "vmla.f32 q12, q4, q0 @ w3 * inr20\n" \ + "vmla.f32 q13, q4, q1 @ w3 * inr21\n" \ + "vmla.f32 q14, q4, q2 @ w3 * inr22\n" \ + "vmla.f32 q15, q4, q3 @ w3 * inr23\n" \ + "vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" \ + "vmla.f32 q8, q7, q0 @ w6 * inr20\n" \ + "vmla.f32 q9, q7, q1 @ w6 * inr21\n" \ + "vmla.f32 q10, q7, q2 @ w6 * inr22\n" \ + "vmla.f32 q11, q7, q3 @ w6 * inr23\n" \ + /* mul r2 with w4, w7, get out r1, r0 */ \ + "vmla.f32 q8, q4, q1 @ w7 * inr21\n" \ + "vmla.f32 q12, q5, q1 @ w4 * inr21\n" \ + "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" \ + "vmla.f32 q9, q4, q2 @ w7 * inr22\n" \ + "vmla.f32 q13, q5, q2 @ w4 * inr22\n" \ + "vmla.f32 q10, q4, q3 @ w7 * inr23\n" \ + "vmla.f32 q14, q5, q3 @ w4 * inr23\n" \ + "vmla.f32 q11, q4, q0 @ w7 * inr24\n" \ + "vmla.f32 q15, q5, q0 @ w4 * inr24\n" \ + "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" \ + /* mul r1 with w5, w8, get out r1, r0 */ \ + "vmla.f32 q12, q6, q2 @ w5 * inr22\n" \ + "vmla.f32 q13, q6, q3 @ w5 * inr23\n" \ + "vmla.f32 q8, q5, q2 @ w8 * inr22\n" \ + "vmla.f32 q9, q5, q3 @ w8 * inr23\n" \ + "vld1.32 {d4-d7}, [%[r3]]! 
@ load r3, q2, q3\n" \ + "ldr r4, [%[outl], #32] @ load bias addr to r4\n" \ + "vmla.f32 q14, q6, q0 @ w5 * inr24\n" \ + "vmla.f32 q15, q6, q1 @ w5 * inr25\n" \ + "vmla.f32 q10, q5, q0 @ w8 * inr24\n" \ + "vmla.f32 q11, q5, q1 @ w8 * inr25\n" \ + "vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n" \ + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \ + /* mul r3 with w6, w7, w8, get out r1 */ \ + "vmla.f32 q12, q7, q2 @ w6 * inr30\n" \ + "vmla.f32 q13, q7, q3 @ w6 * inr31\n" \ + "vmla.f32 q14, q7, q0 @ w6 * inr32\n" \ + "vmla.f32 q15, q7, q1 @ w6 * inr33\n" \ + "vmla.f32 q12, q4, q3 @ w7 * inr31\n" \ + "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" \ + "vld1.32 {d12-d13}, [r4] @ load bias\n" \ + "vmla.f32 q13, q4, q0 @ w7 * inr32\n" \ + "vmla.f32 q14, q4, q1 @ w7 * inr33\n" \ + "vmla.f32 q15, q4, q2 @ w7 * inr34\n" \ + "ldr r0, [%[outl]] @ load outc00 to r0\n" \ + "vmla.f32 q12, q5, q0 @ w8 * inr32\n" \ + "vmla.f32 q13, q5, q1 @ w8 * inr33\n" \ + "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \ + "vmla.f32 q14, q5, q2 @ w8 * inr34\n" \ + "vmla.f32 q15, q5, q3 @ w8 * inr35\n" \ + "ldr r1, [%[outl], #4] @ load outc10 to r1\n" \ + "vadd.f32 q8, q8, q6 @ r00 add bias\n" \ + "vadd.f32 q9, q9, q6 @ r01 add bias\n" \ + "vadd.f32 q10, q10, q6 @ r02 add bias\n" \ + "vadd.f32 q11, q11, q6 @ r03 add bias\n" \ + "ldr r2, [%[outl], #8] @ load outc20 to r2\n" \ + "vadd.f32 q12, q12, q6 @ r10 add bias\n" \ + "vadd.f32 q13, q13, q6 @ r11 add bias\n" \ + "vadd.f32 q14, q14, q6 @ r12 add bias\n" \ + "vadd.f32 q15, q15, q6 @ r13 add bias\n" \ + "ldr r3, [%[outl], #12] @ load outc30 to r3\n" \ + "vmov.u32 q7, #0 @ mov zero to q7\n" +#define RELU \ + "vmax.f32 q8, q8, q7 @ r00 relu\n" \ + "vmax.f32 q9, q9, q7 @ r01 relu\n" \ + "vmax.f32 q10, q10, q7 @ r02 relu\n" \ + "vmax.f32 q11, q11, q7 @ r03 relu\n" \ + "vmax.f32 q12, q12, q7 @ r10 relu\n" \ + "vmax.f32 q13, q13, q7 @ r11 relu\n" \ + "vmax.f32 q14, q14, q7 @ r12 relu\n" \ + "vmax.f32 q15, q15, q7 @ r13 relu\n" + +#define RELU6 \ + "ldr r4, [%[outl], #40] @ load six to r4\n" \ + "vld1.32 {d12-d13}, [r4] @load data \n" \ + "vmin.f32 q8, q8, q6 @ r00 relu\n" \ + "vmin.f32 q9, q9, q6 @ r01 relu\n" \ + "vmin.f32 q10, q10, q6 @ r02 relu\n" \ + "vmin.f32 q11, q11, q6 @ r03 relu\n" \ + "vmin.f32 q12, q12, q6 @ r10 relu\n" \ + "vmin.f32 q13, q13, q6 @ r11 relu\n" \ + "vmin.f32 q14, q14, q6 @ r12 relu\n" \ + "vmin.f32 q15, q15, q6 @ r13 relu\n" + +#define LEAKY_RELU \ + "ldr r4, [%[outl], #44] @ load scale to r4\n" \ + "vld1.32 {d12-d13}, [r4] @load data \n" \ + "vcge.f32 q0, q8, q7 @ q0 > 0 \n" \ + "vcge.f32 q1, q9, q7 @ q0 > 0 \n" \ + "vmul.f32 q4, q8, q6 \n" \ + "vmul.f32 q5, q9, q6 \n" \ + "vcge.f32 q2, q10, q7 @ q0 > 0 \n" \ + "vcge.f32 q3, q11, q7 @ q0 > 0 \n" \ + "vbif q8, q4, q0 @ choose \n" \ + "vbif q9, q5, q1 @ choose \n" \ + "vmul.f32 q4, q10, q6 \n" \ + "vmul.f32 q5, q11, q6 \n" \ + "vbif q10, q4, q2 @ choose \n" \ + "vbif q11, q5, q3 @ choose \n" \ + "vcge.f32 q0, q12, q7 @ q0 > 0 \n" \ + "vcge.f32 q1, q13, q7 @ q0 > 0 \n" \ + "vmul.f32 q4, q12, q6 \n" \ + "vmul.f32 q5, q13, q6 \n" \ + "vcge.f32 q2, q14, q7 @ q0 > 0 \n" \ + "vcge.f32 q3, q15, q7 @ q0 > 0 \n" \ + "vbif q12, q4, q0 @ choose \n" \ + "vbif q13, q5, q1 @ choose \n" \ + "vmul.f32 q4, q14, q6 \n" \ + "vmul.f32 q5, q15, q6 \n" \ + "vbif q14, q4, q2 @ choose \n" \ + "vbif q15, q5, q3 @ choose \n" + +#define STORE \ + "ldr r4, [%[outl], #16] @ load outc01 to r4\n" \ + "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" \ + "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \ + 
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \ + "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \ + "ldr r5, [%[outl], #20] @ load outc11 to r5\n" \ + "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \ + "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" \ + "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \ + "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" \ + "cmp %[flag_mask], #0 @ cmp flag mask\n" \ + "bne 2f\n" \ + "vst1.32 {d16-d17}, [r0] @ save outc00\n" \ + "vst1.32 {d18-d19}, [r1] @ save outc10\n" \ + "vst1.32 {d20-d21}, [r2] @ save outc20\n" \ + "vst1.32 {d22-d23}, [r3] @ save outc30\n" \ + "vst1.32 {d24-d25}, [r4] @ save outc01\n" \ + "vst1.32 {d26-d27}, [r5] @ save outc11\n" \ + "ldr r0, [%[outl], #24] @ load outc21 to r0\n" \ + "ldr r1, [%[outl], #28] @ load outc31 to r1\n" \ + "vst1.32 {d28-d29}, [r0] @ save outc21\n" \ + "vst1.32 {d30-d31}, [r1] @ save outc31\n" \ + "b 3f @ branch end\n" \ + "2: \n" \ + "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" \ + "3: \n" +#endif +// clang-format on +void act_switch_3x3s1(const float* inr0, + const float* inr1, + const float* inr2, + const float* inr3, + float* out0, + const float* weight_c, + float flag_mask, + void* outl_ptr, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + float32x4_t w7, + float32x4_t w8, + float32x4_t vbias, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] 
"r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + } +} +void conv_3x3s1_depthwise_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + const operators::ActivationParam act_param, + ARMContext* ctx) { + int threads = ctx->threads(); + + auto paddings = *param.paddings; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + + const int out_c_block = 4; + const int out_h_kernel = 2; + const int out_w_kernel = 4; + 
const int win_ext = ow + 2; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh + 2; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = + threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + float* ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + float six_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f}; + float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + break; + case lite_api::ActivationType::kRelu6: + six_ptr[0] = act_param.Relu_clipped_coef; + six_ptr[1] = act_param.Relu_clipped_coef; + six_ptr[2] = act_param.Relu_clipped_coef; + six_ptr[3] = act_param.Relu_clipped_coef; + break; + case lite_api::ActivationType::kLeakyRelu: + scale_ptr[0] = act_param.Leaky_relu_alpha; + scale_ptr[1] = act_param.Leaky_relu_alpha; + scale_ptr[2] = act_param.Leaky_relu_alpha; + scale_ptr[3] = act_param.Leaky_relu_alpha; + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + float pre_out[out_c_block * out_w_kernel * out_h_kernel]; // NOLINT + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 9; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + float32x4_t vbias = vld1q_f32(bias_local); +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc00 = dout_c00 + h * ow; + float* outc01 = outc00 + ow; + float* outc10 = outc00 + size_out_channel; + float* outc11 = outc10 + ow; + float* outc20 = outc10 + size_out_channel; + float* outc21 = outc20 + ow; + float* outc30 = outc20 + size_out_channel; + float* outc31 = outc30 + ow; + 
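The tail handling just below relies on deliberate switch fall-through: output pointers for channels or rows past the tensor edge are redirected to the shared ptr_write scratch row, so the NEON stores never need a conditional path. The same idiom in isolation (a sketch; names are illustrative):

#include <cstdio>

// k = c + 4 - oc counts how many of the 4 packed channels are out of range;
// case k redirects outc[4 - k] and falls through to kill the rest.
static void redirect_tail(float* outc[4], float* scratch, int c, int oc) {
  if (c + 4 > oc) {
    switch (c + 4 - oc) {
      case 3: outc[1] = scratch;  // fall through
      case 2: outc[2] = scratch;  // fall through
      case 1: outc[3] = scratch;
      default: break;
    }
  }
}

int main() {
  float real[4], trash[4];
  float* outc[4] = {real, real, real, real};
  redirect_tail(outc, trash, 6, 8);  // channels 6..9 requested, but oc == 8
  printf("%d pointers redirected\n", (outc[2] == trash) + (outc[3] == trash));
}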
const float* inr0 = pre_din + h * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + case 3: + outc10 = ptr_write; + outc11 = ptr_write; + case 2: + outc20 = ptr_write; + outc21 = ptr_write; + case 1: + outc30 = ptr_write; + outc31 = ptr_write; + default: + break; + } + } + if (h + out_h_kernel > oh) { + outc01 = ptr_write; + outc11 = ptr_write; + outc21 = ptr_write; + outc31 = ptr_write; + } + + float* outl[] = {outc00, + outc10, + outc20, + outc30, + outc01, + outc11, + outc21, + outc31, + reinterpret_cast(bias_local), + reinterpret_cast(relu_ptr), + reinterpret_cast(six_ptr), + reinterpret_cast(scale_ptr)}; + void* outl_ptr = reinterpret_cast(outl); + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + float* out0 = pre_out; +#ifdef __aarch64__ + act_switch_3x3s1(inr0, + inr1, + inr2, + inr3, + out0, + weight_c, + flag_mask, + outl_ptr, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + w7, + w8, + vbias, + act_param); +#else + act_switch_3x3s1(inr0, + inr1, + inr2, + inr3, + out0, + weight_c, + flag_mask, + outl_ptr, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + act_param); +#endif + outl[0] += 4; + outl[1] += 4; + outl[2] += 4; + outl[3] += 4; + outl[4] += 4; + outl[5] += 4; + outl[6] += 4; + outl[7] += 4; + inr0 += 16; + inr1 += 16; + inr2 += 16; + inr3 += 16; + if (flag_mask) { + memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); + memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); + memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); + memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); + memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); + memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); + memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); + memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); + } + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc deleted file mode 100644 index 2d75323a9677f1cfbed726a1a28920dd77131688..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc +++ /dev/null @@ -1,361 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
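For orientation while reading the stride-2 kernel deleted below (presumably superseded by an activation-fused counterpart in the style of the s1 file above): a 3x3 stride-2 window over ow outputs spans 2*ow + 1 padded columns and, per output row, three padded input rows — which is why the deleted code packs win_ext = ow * 2 + 1 columns and hin_round = oh * 2 + 1 rows. A scalar sketch of the stride-2 indexing, not part of the patch:

static void depthwise_3x3s2_ref(float* out, const float* in, const float* w9,
                                float bias, int h_in, int w_in, int h_out,
                                int w_out, int pad) {
  for (int oh = 0; oh < h_out; ++oh) {
    for (int ow = 0; ow < w_out; ++ow) {
      float sum = bias;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          const int ih = 2 * oh - pad + kh;  // stride-2 row index
          const int iw = 2 * ow - pad + kw;  // stride-2 column index
          if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in)
            sum += w9[kh * 3 + kw] * in[ih * w_in + iw];
        }
      }
      out[oh * w_out + ow] = sum;
    }
  }
}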
-
-#include <arm_neon.h>
-#include "lite/backends/arm/math/conv_block_utils.h"
-#include "lite/backends/arm/math/conv_impl.h"
-#include "lite/core/context.h"
-#include "lite/operators/op_params.h"
-#ifdef ARM_WITH_OMP
-#include <omp.h>
-#endif
-
-namespace paddle {
-namespace lite {
-namespace arm {
-namespace math {
-
-void conv_3x3s2_depthwise_fp32(const float* i_data,
-                               float* o_data,
-                               int bs,
-                               int oc,
-                               int oh,
-                               int ow,
-                               int ic,
-                               int ih,
-                               int win,
-                               const float* weights,
-                               const float* bias,
-                               const operators::ConvParam& param,
-                               ARMContext* ctx) {
-  int threads = ctx->threads();
-  const int pad_h = param.paddings[0];
-  const int pad_w = param.paddings[1];
-  const int out_c_block = 4;
-  const int out_h_kernel = 1;
-  const int out_w_kernel = 4;
-  const int win_ext = ow * 2 + 1;
-  const int ow_round = ROUNDUP(ow, 4);
-  const int win_round = ROUNDUP(win_ext, 4);
-  const int hin_round = oh * 2 + 1;
-  const int prein_size = win_round * hin_round * out_c_block;
-  auto workspace_size =
-      threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
-  ctx->ExtendWorkspace(sizeof(float) * workspace_size);
-
-  bool flag_relu = param.fuse_relu;
-  bool flag_bias = param.bias != nullptr;
-
-  /// get workspace
-  auto ptr_zero = ctx->workspace_data<float>();
-  memset(ptr_zero, 0, sizeof(float) * win_round);
-  float* ptr_write = ptr_zero + win_round;
-
-  int size_in_channel = win * ih;
-  int size_out_channel = ow * oh;
-
-  int ws = -pad_w;
-  int we = ws + win_round;
-  int hs = -pad_h;
-  int he = hs + hin_round;
-  int w_loop = ow_round / 4;
-  auto remain = w_loop * 4 - ow;
-  bool flag_remain = remain > 0;
-  remain = 4 - remain;
-  remain = remain > 0 ? remain : 0;
-  int row_len = win_round * out_c_block;
-
-  for (int n = 0; n < bs; ++n) {
-    const float* din_batch = i_data + n * ic * size_in_channel;
-    float* dout_batch = o_data + n * oc * size_out_channel;
-#pragma omp parallel for num_threads(threads)
-    for (int c = 0; c < oc; c += out_c_block) {
-#ifdef ARM_WITH_OMP
-      float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size;
-#else
-      float* pre_din = ptr_write + ow_round;
-#endif
-      /// const array size
-      prepack_input_nxwc4_dw(
-          din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
-      const float* weight_c = weights + c * 9;  // kernel_w * kernel_h
-      float* dout_c00 = dout_batch + c * size_out_channel;
-      float bias_local[4] = {0, 0, 0, 0};
-      if (flag_bias) {
-        bias_local[0] = bias[c];
-        bias_local[1] = bias[c + 1];
-        bias_local[2] = bias[c + 2];
-        bias_local[3] = bias[c + 3];
-      }
-#ifdef __aarch64__
-      float32x4_t w0 = vld1q_f32(weight_c);       // w0, v23
-      float32x4_t w1 = vld1q_f32(weight_c + 4);   // w1, v24
-      float32x4_t w2 = vld1q_f32(weight_c + 8);   // w2, v25
-      float32x4_t w3 = vld1q_f32(weight_c + 12);  // w3, v26
-      float32x4_t w4 = vld1q_f32(weight_c + 16);  // w4, v27
-      float32x4_t w5 = vld1q_f32(weight_c + 20);  // w5, v28
-      float32x4_t w6 = vld1q_f32(weight_c + 24);  // w6, v29
-      float32x4_t w7 = vld1q_f32(weight_c + 28);  // w7, v30
-      float32x4_t w8 = vld1q_f32(weight_c + 32);  // w8, v31
-#endif
-      for (int h = 0; h < oh; h += out_h_kernel) {
-        float* outc0 = dout_c00 + h * ow;
-        float* outc1 = outc0 + size_out_channel;
-        float* outc2 = outc1 + size_out_channel;
-        float* outc3 = outc2 + size_out_channel;
-        const float* inr0 = pre_din + h * 2 * row_len;
-        const float* inr1 = inr0 + row_len;
-        const float* inr2 = inr1 + row_len;
-        if (c + out_c_block > oc) {
-          switch (c + out_c_block - oc) {
-            case 3:
-              outc1 = ptr_write;
-            case 2:
-              outc2 = ptr_write;
-            case 1:
-              outc3 = ptr_write;
default: - break; - } - } - auto c0 = outc0; - auto c1 = outc1; - auto c2 = outc2; - auto c3 = outc3; - float pre_out[16]; - for (int w = 0; w < w_loop; ++w) { - bool flag_mask = (w == w_loop - 1) && flag_remain; - if (flag_mask) { - c0 = outc0; - c1 = outc1; - c2 = outc2; - c3 = outc3; - outc0 = pre_out; - outc1 = pre_out + 4; - outc2 = pre_out + 8; - outc3 = pre_out + 12; - } -// clang-format off -#ifdef __aarch64__ - asm volatile( - "ldr q8, [%[bias]]\n" /* load bias */ - "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ - "and v19.16b, v8.16b, v8.16b\n" - "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ - "and v20.16b, v8.16b, v8.16b\n" - "ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/ - "and v21.16b, v8.16b, v8.16b\n" - "ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/ - "and v22.16b, v8.16b, v8.16b\n" - "ldr q8, [%[inr0]]\n" /* load input r0*/ - /* r0 mul w0-w2, get out */ - "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ - "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ - "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ - "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ - "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ - "ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/ - "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ - "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ - "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ - "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ - "ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/ - "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ - "ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/ - "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ - "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ - "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ - "ldr q8, [%[inr1]]\n" /* load input r1*/ - /* r1, mul w3-w5, get out */ - "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/ - "fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/ - "fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/ - "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/ - "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/ - "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ - "fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/ - "fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/ - "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/ - "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/ - "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ - "fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/ - "ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/ - "fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/ - "ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/ - "fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/ - "ldr q8, [%[inr2]]\n" /* load input r2*/ - /* r2, mul w6-w8, get out r0, r1 */ - "fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/ - "fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/ - "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/ - "fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/ - "fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/ - "fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/ - "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/ - "fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/ - "fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/ - "fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/ - "fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = 
w8 * r2, 6*/ - "fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/ - /* transpose */ - "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ - "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ - "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ - "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ - "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ - "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ - "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ - "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ - /* relu */ - "cbz %w[flag_relu], 0f\n" /* skip relu*/ - "movi v0.4s, #0\n" /* for relu */ - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - /* save result */ - "0:\n" - "str q19, [%[outc0]], #16\n" - "str q20, [%[outc1]], #16\n" - "str q21, [%[outc2]], #16\n" - "str q22, [%[outc3]], #16\n" - :[inr0] "+r"(inr0), [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [outc0]"+r"(outc0), [outc1]"+r"(outc1), - [outc2]"+r"(outc2), [outc3]"+r"(outc3) - :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), - [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), - [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), - [bias] "r" (bias_local), [flag_relu]"r"(flag_relu) - : "cc", "memory", - "v0","v1","v2","v3","v4","v5","v6","v7", - "v8", "v19","v20","v21","v22" - ); -#else - asm volatile( - /* fill with bias */ - "vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ - /* load weights */ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */ - "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ - "vand.i32 q12, q8, q8\n" - "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ - "vand.i32 q13, q8, q8\n" - "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ - "vand.i32 q14, q8, q8\n" - "vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/ - "vand.i32 q15, q8, q8\n" - "vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/ - /* mul r0 with w0, w1, w2 */ - "vmla.f32 q12, q9, q0 @ w0 * inr0\n" - "vmla.f32 q13, q9, q2 @ w0 * inr2\n" - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */ - "vmla.f32 q14, q9, q4 @ w0 * inr4\n" - "vmla.f32 q15, q9, q6 @ w0 * inr6\n" - "vmla.f32 q12, q10, q1 @ w1 * inr1\n" - "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" - "vmla.f32 q13, q10, q3 @ w1 * inr3\n" - "vmla.f32 q14, q10, q5 @ w1 * inr5\n" - "vmla.f32 q15, q10, q7 @ w1 * inr7\n" - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */ - "vmla.f32 q12, q11, q2 @ w2 * inr2\n" - "vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n" - "vmla.f32 q13, q11, q4 @ w2 * inr4\n" - "vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n" - "vmla.f32 q14, q11, q6 @ w2 * inr6\n" - "vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n" - "vmla.f32 q15, q11, q8 @ w2 * inr8\n" - /* mul r1 with w3, w4, w5 */ - "vmla.f32 q12, q9, q0 @ w3 * inr0\n" - "vmla.f32 q13, q9, q2 @ w3 * inr2\n" - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */ - "vmla.f32 q14, q9, q4 @ w3 * inr4\n" - "vmla.f32 q15, q9, q6 @ w3 * inr6\n" - "vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/ - "vmla.f32 q12, q10, q1 @ w4 * inr1\n" - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" - "vmla.f32 q13, q10, q3 @ w4 * inr3\n" - "vmla.f32 q14, q10, q5 @ w4 * inr5\n" - "vmla.f32 q15, q10, q7 @ w4 * inr7\n" - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ - "vmla.f32 q12, q11, q2 @ w5 * inr2\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" - "vmla.f32 q13, q11, q4 @ w5 * inr4\n" - "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" - "vmla.f32 q14, q11, q6 @ w5 * inr6\n" - "vld1.32 {d12-d15}, [%[r2]]! 
@ load r2, 6, 7\n" - "vmla.f32 q15, q11, q8 @ w5 * inr8\n" - /* mul r2 with w6, w7, w8 */ - "vmla.f32 q12, q9, q0 @ w6 * inr0\n" - "vmla.f32 q13, q9, q2 @ w6 * inr2\n" - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */ - "vmla.f32 q14, q9, q4 @ w6 * inr4\n" - "vmla.f32 q15, q9, q6 @ w6 * inr6\n" - "vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/ - "vmla.f32 q12, q10, q1 @ w7 * inr1\n" - "vmla.f32 q13, q10, q3 @ w7 * inr3\n" - "vmla.f32 q14, q10, q5 @ w7 * inr5\n" - "vmla.f32 q15, q10, q7 @ w7 * inr7\n" - "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" - "vmla.f32 q12, q11, q2 @ w8 * inr2\n" - "vmla.f32 q13, q11, q4 @ w8 * inr4\n" - "vmla.f32 q14, q11, q6 @ w8 * inr6\n" - "vmla.f32 q15, q11, q8 @ w8 * inr8\n" - /* transpose */ - "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ - "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ - "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ - "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ - "cmp %[flag_relu], #0\n" - "beq 0f\n" /* skip relu*/ - "vmov.u32 q0, #0\n" - "vmax.f32 q12, q12, q0\n" - "vmax.f32 q13, q13, q0\n" - "vmax.f32 q14, q14, q0\n" - "vmax.f32 q15, q15, q0\n" - "0:\n" - "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ - "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ - "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ - "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ - :[r0] "+r"(inr0), [r1] "+r"(inr1), - [r2] "+r"(inr2), [wc0] "+r" (weight_c), - [outc0]"+r"(outc0), [outc1]"+r"(outc1), - [outc2]"+r"(outc2), [outc3]"+r"(outc3) - :[bias] "r" (bias_local), - [flag_relu]"r"(flag_relu) - :"cc", "memory", - "q0","q1","q2","q3","q4","q5","q6","q7", - "q8", "q9","q10","q11","q12","q13","q14","q15" - ); -#endif // __arch64__ - // clang-format off - if (flag_mask) { - for (int i = 0; i < remain; ++i) { - c0[i] = pre_out[i]; - c1[i] = pre_out[i + 4]; - c2[i] = pre_out[i + 8]; - c3[i] = pre_out[i + 12]; - } - } - } - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc index 8260718a50f8e2fa8497d41d958e82a45ea0480d..f5b196efcca3f3f35367f2fea5e8f475b7147f48 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc @@ -32,10 +32,11 @@ size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param, ARMContext* ctx) { auto dim_in = param.x->dims(); auto dim_out = param.output->dims(); + auto paddings = *param.paddings; const int threads = ctx->threads(); int llc_size = ctx->llc_size() / sizeof(float); - const int pad_w = param.paddings[1]; - const int pad_h = param.paddings[0]; + const int pad_w = paddings[2]; + const int pad_h = paddings[0]; int ow = dim_out[3]; int oh = dim_out[2]; int ic = dim_in[1]; @@ -73,10 +74,12 @@ void conv_3x3s2_direct_fp32(const float* i_data, //! 3x3s2 convolution, implemented by direct algorithm //! prepack input to tmp buffer //! 
write output to tmp buffer + auto paddings = *param.paddings; + auto act_param = param.activation_param; const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); - const int pad_w = param.paddings[1]; - const int pad_h = param.paddings[0]; + const int pad_w = paddings[2]; + const int pad_h = paddings[0]; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round * 2 /*stride_w*/ + 1; bool flag_relu = param.fuse_relu; @@ -508,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } #pragma omp parallel for num_threads(threads) @@ -837,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } } } diff --git a/lite/backends/arm/math/conv3x3s2_direct_int8.cc b/lite/backends/arm/math/conv3x3s2_direct_int8.cc index 01b7a812ebc05a054bb9952bf53605ce7aed135a..3d6f3dd743c3e46b6123f2c93dbfed586ad7b4c6 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_int8.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_int8.cc @@ -46,10 +46,11 @@ void conv_3x3s2_direct_int8(const int8_t* din, //! 3x3s2 int8 convolution, implemented by direct algorithm //! prepack input to tmp buffer //! write output to tmp buffer + auto paddings = *param.paddings; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; const int threads = ctx->threads(); int llc_size = ctx->llc_size() / 4; @@ -472,10 +473,11 @@ void conv_3x3s2_direct_int8(const int8_t* din, //! 3x3s2 int8 convolution, implemented by direct algorithm //! prepack input to tmp buffer //! write output to tmp buffer + auto paddings = *param.paddings; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; const int threads = ctx->threads(); //! set 1/4 l2 cache int llc_size = ctx->llc_size() / 4; diff --git a/lite/backends/arm/math/conv_depthwise_3x3s2.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc similarity index 60% rename from lite/backends/arm/math/conv_depthwise_3x3s2.cc rename to lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index ec039af98cb7e4fb037475dd4e5ee29204252165..3e5569365119b97397c6d42f48bacd2552b248e5 100644 --- a/lite/backends/arm/math/conv_depthwise_3x3s2.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
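A recurring change in the hunks above is the padding access: `param.paddings` is now dereferenced through a `shared_ptr` to a four-element vector, and the width pad moves from index 1 to index 2. My reading of the new convention, shown as a small sketch (the {top, bottom, left, right} layout is inferred from the indices used in this patch, not from a spec):

```cpp
#include <memory>
#include <vector>

int main() {
  // paddings = {top, bottom, left, right}; layout inferred from this patch.
  auto paddings_ptr =
      std::make_shared<std::vector<int>>(std::vector<int>{1, 1, 2, 2});
  auto paddings = *paddings_ptr;  // same dereference as in the hunks above
  const int pad_h = paddings[0];  // top pad, unchanged at index 0
  const int pad_w = paddings[2];  // left pad, previously paddings[1]
  return (pad_h == 1 && pad_w == 2) ? 0 : 1;
}
```

Dereferencing once into a local (`auto paddings = *param.paddings;`) also avoids repeating the pointer indirection inside the hot loops.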
-#include "lite/backends/arm/math/conv_depthwise.h" #include +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_depthwise.h" namespace paddle { namespace lite { @@ -24,13 +25,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2p0_bias_s(float* dout, @@ -38,13 +39,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2p1_bias(float* dout, @@ -52,13 +53,13 @@ void conv_depthwise_3x3s2p1_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2p1_bias_s(float* dout, @@ -66,13 +67,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2_fp32(const float* din, @@ -88,7 +89,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, const float* bias, int pad, bool flag_bias, - bool flag_relu, + const operators::ActivationParam act_param, ARMContext* ctx) { if (pad == 0) { if (w_in > 7) { @@ -97,13 +98,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s2p0_bias_s(dout, @@ -111,13 +112,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } } @@ -128,13 +129,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s2p1_bias_s(dout, @@ -142,13 +143,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } } @@ -205,14 +206,12 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \ "fadd v16.4s, v16.4s, v11.4s \n" \ - "fadd v16.4s, v16.4s, v12.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" #define LEFT_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ - \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \ \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ @@ -244,53 +243,52 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "blt 1f \n" -#define MID_COMPUTE_S2 \ - "2: \n" /* r0 */ \ - "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ - "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ - "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ - \ - "ext v10.16b, v2.16b, v18.16b, #4 \n" \ - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ 
\ - "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ - "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ - "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ - \ - "ext v10.16b, v4.16b, v19.16b, #4 \n" \ - \ - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ - "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ - "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ - \ - "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ - "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ - \ - "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ - "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ext v10.16b, v6.16b, v20.16b, #4 \n" \ - \ - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ - "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ - "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ - \ - "ext v10.16b, v8.16b, v21.16b, #4 \n" \ - \ - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ - \ - "fadd v16.4s, v16.4s, v11.4s \n" \ - "fadd v16.4s, v16.4s, v12.4s \n" +#define MID_COMPUTE_S2 \ + "2: \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, v18.16b, #4 \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, v19.16b, #4 \n" \ + \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, v20.16b, #4 \n" \ + \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, v21.16b, #4 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" #define MID_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -360,14 +358,12 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "fadd v16.4s, v16.4s, v11.4s \n" \ "fadd v16.4s, v16.4s, v12.4s \n" \ - "ld1 {v1.4s}, [%[outptr1]] \n" + "ld1 {v1.4s}, [%[outptr1]] \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" #define RIGHT_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ "bif v16.16b, v0.16b, %[wmask].16b \n" \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -382,11 +378,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "4: \n" #define LEFT_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ - \ "fmax v16.4s, v16.4s, %[vzero].4s \n" \ \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ @@ -422,16 +413,85 
@@ void conv_depthwise_3x3s2_fp32(const float* din, "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ \ "blt 1f \n" +#define LEFT_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" \ + "ld1 {v22.4s}, [%[six_ptr]] \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define LEFT_RESULT_S2_LEAKY_RELU \ + "ld1 {v22.4s}, [%[scale_ptr]] \n" \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" #define MID_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -456,12 +516,59 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "bne 2b \n" -#define RIGHT_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ +#define MID_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, 
%[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define MID_RESULT_S2_LEAKY_RELU \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -479,6 +586,47 @@ void conv_depthwise_3x3s2_fp32(const float* din, "st1 {v17.4s}, [%[outptr1]], #16 \n" \ "4: \n" +#define RIGHT_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define RIGHT_RESULT_S2_LEAKY_RELU \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + #define COMPUTE_S_S2 \ "movi v9.4s, #0 \n" \ "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ @@ -523,7 +671,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "fmax v4.4s, v4.4s, v9.4s \n" \ \ "st1 {v4.4s}, [%[out]] \n" - #define COMPUTE_S_S2_P0 \ "movi v9.4s, #0 \n" \ "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ @@ -560,7 +707,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "fadd v4.4s, v4.4s, v16.4s \n" #define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n" - #define RESULT_S_S2_P0_RELU \ "fmax v4.4s, v4.4s, v9.4s \n" \ "st1 {v4.4s}, [%[out]] \n" @@ -705,7 +851,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "vst1.32 {d6-d7}, [%[outptr]]! 
\n" \ "cmp %[cnt], #1 \n" \ "blt 1f \n" - #define MID_RESULT_S2_RELU \ "vmax.f32 q3, q3, q9 @ relu \n" \ "subs %[cnt], #1 \n" \ @@ -762,7 +907,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "vadd.f32 q3, q3, q5 @ add \n" #define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n" - #define RESULT_S_S2_RELU \ "vmax.f32 q3, q3, q9 @ relu\n" \ \ @@ -810,13 +954,233 @@ void conv_depthwise_3x3s2_fp32(const float* din, "vadd.f32 q3, q3, q5 @ add \n" #define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n" - #define RESULT_S_S2_P0_RELU \ "vmax.f32 q3, q3, q9 @ relu \n" \ "vst1.32 {d6-d7}, [%[out]] \n" - #endif - +#ifdef __aarch64__ +void act_switch_3x3s2p1(const float* din0_ptr, + const float* din1_ptr, + const float* din2_ptr, + const float* din3_ptr, + const float* din4_ptr, + float* doutr0_ptr, + float* doutr1_ptr, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + uint32x4_t wmask, + float32x4_t wbias, + float32x4_t vzero, + int cnt, + int cnt_remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 + MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(vsix), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU + MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU + RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [scale_ptr] "r"(vscale), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } +} +#endif /** * \brief depthwise convolution kernel 3x3, stride 2 * w_in > 7 @@ -826,27 +1190,29 @@ void conv_depthwise_3x3s2p1_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; int size_pad_bottom = h_out * 2 - h_in; - int cnt_col = (w_out >> 2) - 2; - int size_right_remain = w_in - (7 + cnt_col * 8); - if (size_right_remain >= 9) { - cnt_col++; - size_right_remain -= 8; - } - int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; - int size_right_pad = w_out * 2 - w_in; + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + int cnt_col = tile_w - 1; uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), vld1q_s32(right_pad_idx)); // 0 2 4 6 @@ -912,7 +1278,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, float* doutr1_ptr = nullptr; #ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { + for (int i = 0; i < h_out; i += 2) { din0_ptr = dr0; din1_ptr = dr1; din2_ptr = dr2; @@ -939,8 +1305,8 @@ void conv_depthwise_3x3s2p1_bias(float* dout, dr4 = dr3 + w_in; //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { + if (i * 2 + 4 > h_in) { + switch (i * 2 + 4 - h_in) { case 4: din1_ptr = zero_ptr; case 3: @@ -954,104 +1320,32 @@ void conv_depthwise_3x3s2p1_bias(float* dout, } } //! 
process output pad - if (i / 2 + 2 > h_out) { + if (i + 2 > h_out) { doutr1_ptr = write_ptr; } int cnt = cnt_col; - if (flag_relu) { - asm volatile( - INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 - MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } else { - asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 - MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } + act_switch_3x3s2p1(din0_ptr, + din1_ptr, + din2_ptr, + din3_ptr, + din4_ptr, + doutr0_ptr, + doutr1_ptr, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + wmask, + wbias, + vzero, + cnt, + cnt_remain, + act_param); doutr0 = doutr0 + 2 * w_out; } #else - for (int i = 0; i < h_in; i += 2) { + for (int i = 0; i < h_out; i++) { din0_ptr = dr0; din1_ptr = dr1; din2_ptr = dr2; @@ -1072,8 +1366,8 @@ void conv_depthwise_3x3s2p1_bias(float* dout, } //! 
process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { + if (i * 2 + 2 > h_in) { + switch (i * 2 + 2 - h_in) { case 2: din1_ptr = zero_ptr; case 1: @@ -1084,65 +1378,37 @@ void conv_depthwise_3x3s2p1_bias(float* dout, } int cnt = cnt_col; unsigned int* mask_ptr = dmask; - if (flag_relu) { - asm volatile( - INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 - MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 - MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + // do act + if (act_param.has_active) { + act_switch_process(doutr0, doutr0, w_out, &act_param); } doutr0 = doutr0 + w_out; } @@ -1159,13 +1425,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -1221,108 +1487,59 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, unsigned int* mask_ptr = dmask; #ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } else { - asm volatile(COMPUTE_S_S2 RESULT_S_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) 
+ : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); #else - if (flag_relu) { - asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S2 RESULT_S_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif + // do act + if (act_param.has_active) { + act_switch_process(out_buf, out_buf, w_out, &act_param); + } for (int w = 0; w < w_out; ++w) { *dout_channel++ = out_buf[w]; } @@ -1333,6 +1550,271 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, } } +#ifdef __aarch64__ +void act_switch_3x3s2p0(const float* din0_ptr, + const float* din1_ptr, + const float* din2_ptr, + const float* din3_ptr, + const float* din4_ptr, + float* doutr0_ptr, + float* doutr1_ptr, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + uint32x4_t wmask, + float32x4_t wbias, + float32x4_t vzero, + int cnt, + int cnt_remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] 
\n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU6 + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(vsix), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2 + MID_RESULT_S2_LEAKY_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_LEAKY_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [scale_ptr] "r"(vscale), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } +} +#endif /** * \brief depthwise convolution kernel 3x3, stride 2 */ @@ -1342,13 +1824,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const 
operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -1356,7 +1838,14 @@ void conv_depthwise_3x3s2p0_bias(float* dout, int tile_w = w_out >> 2; int cnt_remain = w_out % 4; - unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); + unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), vld1q_s32(right_pad_idx)); // 0 2 4 6 @@ -1461,117 +1950,24 @@ void conv_depthwise_3x3s2p0_bias(float* dout, doutr1_ptr = write_ptr; } int cnt = tile_w; - if (flag_relu) { - asm volatile( - INIT_S2 - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - MID_COMPUTE_S2 MID_RESULT_S2_RELU - "cmp %w[remain], #1 \n" - "blt 4f \n" RIGHT_COMPUTE_S2 - RIGHT_RESULT_S2_RELU - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } else { - asm volatile( - INIT_S2 - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - MID_COMPUTE_S2 MID_RESULT_S2 - "cmp %w[remain], #1 \n" - "blt 4f \n" RIGHT_COMPUTE_S2 - RIGHT_RESULT_S2 - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } + act_switch_3x3s2p0(din0_ptr, + din1_ptr, + din2_ptr, + din3_ptr, + din4_ptr, + doutr0_ptr, + doutr1_ptr, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + wmask, + wbias, + vzero, + cnt, + cnt_remain, + act_param); doutr0 = doutr0 + 2 * w_out; } #else @@ -1599,64 +1995,36 @@ void conv_depthwise_3x3s2p0_bias(float* dout, } int cnt = tile_w; unsigned int* mask_ptr = dmask; - if (flag_relu) { - asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU - RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), 
- [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 - RIGHT_RESULT_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + if (act_param.has_active) { + act_switch_process(doutr0, doutr0, w_out, &act_param); } doutr0 = doutr0 + w_out; } @@ -1673,13 +2041,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -1741,114 +2109,62 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, unsigned int* mask_ptr = dmask; #ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - } else { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - } + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + #else - if (flag_relu) { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(dmask) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - 
[din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(dmask) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif + if (act_param.has_active) { + act_switch_process(out_buf, out_buf, w_out, &act_param); + } for (int w = 0; w < w_out; ++w) { *dout_channel++ = out_buf[w]; } diff --git a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc new file mode 100644 index 0000000000000000000000000000000000000000..4617d40f4372f6589f20b50205fb307cdc705808 --- /dev/null +++ b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc @@ -0,0 +1,721 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
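On the armv7 paths in the hunks above, the fused-activation variants of the asm were dropped: the plain kernel runs first and `act_switch_process` (from conv_block_utils.h) then applies the activation over the finished row. The stand-in below only illustrates the dispatch shape of such a pass under assumed parameter names; it is not the library implementation:

```cpp
#include <algorithm>

// Hypothetical mirror of the activation kinds used in this patch.
enum class Act { kRelu, kRelu6, kLeakyRelu };

struct ActParam {
  bool has_active{false};
  Act active_type{Act::kRelu};
  float six{6.f};     // stands in for Relu_clipped_coef
  float alpha{0.1f};  // stands in for Leaky_relu_alpha
};

// Apply the fused activation to one finished output row of n floats.
void act_process_row(float* dout, const float* din, int n, const ActParam& p) {
  if (!p.has_active) return;
  for (int i = 0; i < n; ++i) {
    float v = din[i];
    switch (p.active_type) {
      case Act::kRelu:      v = std::max(v, 0.f); break;
      case Act::kRelu6:     v = std::min(std::max(v, 0.f), p.six); break;
      case Act::kLeakyRelu: v = (v >= 0.f) ? v : v * p.alpha; break;
    }
    dout[i] = v;
  }
}
```

This keeps the armv7 asm small at the cost of a second pass over the row; the aarch64 `act_switch_3x3s2p*` helpers instead select an asm variant with the activation fused in.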
+
+#include <arm_neon.h>
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+#ifdef __aarch64__
+#define COMPUTE                                                   \
+  "ldr q8, [%[bias]]\n"            /* load bias */               \
+  "ldp q0, q1, [%[inr0]], #32\n"   /* load input r0*/            \
+  "and v19.16b, v8.16b, v8.16b\n"                                 \
+  "ldp q2, q3, [%[inr0]], #32\n"   /* load input r0*/            \
+  "and v20.16b, v8.16b, v8.16b\n"                                 \
+  "ldp q4, q5, [%[inr0]], #32\n"   /* load input r0*/            \
+  "and v21.16b, v8.16b, v8.16b\n"                                 \
+  "ldp q6, q7, [%[inr0]], #32\n"   /* load input r0*/            \
+  "and v22.16b, v8.16b, v8.16b\n"                                 \
+  "ldr q8, [%[inr0]]\n"            /* load input r0*/            \
+  "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/      \
+  "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/      \
+  "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/      \
+  "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/      \
+  "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/      \
+  "ldp q0, q1, [%[inr1]], #32\n"   /* load input r1*/            \
+  "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/      \
+  "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/      \
+  "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/      \
+  "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w2 * r0, 2*/      \
+  "ldp q2, q3, [%[inr1]], #32\n"   /* load input r1*/            \
+  "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w2 * r0, 4*/      \
+  "ldp q4, q5, [%[inr1]], #32\n"   /* load input r1*/            \
+  "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w2 * r0, 6*/      \
+  "ldp q6, q7, [%[inr1]], #32\n"   /* load input r1*/            \
+  "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w2 * r0, 8*/      \
+  "ldr q8, [%[inr1]]\n"            /* load input r1*/            \
+  "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/      \
+  "fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/      \
+  "fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/      \
+  "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/      \
+  "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/      \
+  "ldp q0, q1, [%[inr2]], #32\n"   /* load input r2*/            \
+  "fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/      \
+  "fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/      \
+  "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/      \
+  "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/      \
+  "ldp q2, q3, [%[inr2]], #32\n"   /* load input r2*/            \
+  "fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/      \
+  "ldp q4, q5, [%[inr2]], #32\n"   /* load input r2*/            \
+  "fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/      \
+  "ldp q6, q7, [%[inr2]], #32\n"   /* load input r2*/            \
+  "fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/      \
+  "ldr q8, [%[inr2]]\n"            /* load input r2*/            \
+  "fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/      \
+  "fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/      \
+  "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/      \
+  "fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/      \
+  "fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/      \
+  "fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/      \
+  "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/      \
+  "fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/      \
+  "fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/      \
+  "fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/      \
+  "fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/      \
+  "fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/      \
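+  /* v19..v22 each hold one output pixel (4 channels); the trn1/trn2 */ \
+  /* pairs below transpose that 4x4 block so each register ends up   */ \
+  /* holding one channel across 4 pixels, matching STORE's           */ \
+  /* per-channel row writes to outc0..outc3.                         */ \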
+  "trn1 v0.4s, v19.4s, v20.4s\n"   /* r0: a0a1c0c1*/             \
+  "trn2 v1.4s, v19.4s, v20.4s\n"   /* r0: b0b1d0d1*/             \
+  "trn1 v2.4s, v21.4s, v22.4s\n"   /* r0: a2a3c2c3*/             \
+  "trn2 v3.4s, v21.4s, v22.4s\n"   /* r0: b2b3d2d3*/             \
+  "trn1 v19.2d, v0.2d, v2.2d\n"    /* r0: a0a1a2a3*/             \
+  "trn2 v21.2d, v0.2d, v2.2d\n"    /* r0: c0c1c2c3*/             \
+  "trn1 v20.2d, v1.2d, v3.2d\n"    /* r0: b0b1b2b3*/             \
+  "trn2 v22.2d, v1.2d, v3.2d\n"    /* r0: d0d1d2d3*/
+#define RELU                       /* relu */                     \
+  "movi v0.4s, #0\n"               /* for relu */                 \
+  "fmax v19.4s, v19.4s, v0.4s\n"                                  \
+  "fmax v20.4s, v20.4s, v0.4s\n"                                  \
+  "fmax v21.4s, v21.4s, v0.4s\n"                                  \
+  "fmax v22.4s, v22.4s, v0.4s\n"
+#define RELU6                      /* relu6 */                    \
+  "fmin v19.4s, v19.4s, %[vsix].4s\n"                             \
+  "fmin v20.4s, v20.4s, %[vsix].4s\n"                             \
+  "fmin v21.4s, v21.4s, %[vsix].4s\n"                             \
+  "fmin v22.4s, v22.4s, %[vsix].4s\n"
+#define LEAKY_RELU                 /* LeakyRelu */                \
+  "movi v0.4s, #0\n"               /* for relu */                 \
+  "fcmge v1.4s, v19.4s, v0.4s \n"  /* vcgeq_u32 */                \
+  "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */                 \
+  "fcmge v3.4s, v20.4s, v0.4s \n"  /* vcgeq_u32 */                \
+  "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */                 \
+  "fcmge v5.4s, v21.4s, v0.4s \n"  /* vcgeq_u32 */                \
+  "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */                 \
+  "fcmge v7.4s, v22.4s, v0.4s \n"  /* vcgeq_u32 */                \
+  "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */                 \
+  "bif v19.16b, v2.16b, v1.16b \n" /* choose*/                    \
+  "bif v20.16b, v4.16b, v3.16b \n" /* choose*/                    \
+  "bif v21.16b, v6.16b, v5.16b \n" /* choose*/                    \
+  "bif v22.16b, v8.16b, v7.16b \n" /* choose*/
+#define STORE                      /* save result */              \
+  "str q19, [%[outc0]], #16\n"                                    \
+  "str q20, [%[outc1]], #16\n"                                    \
+  "str q21, [%[outc2]], #16\n"                                    \
+  "str q22, [%[outc3]], #16\n"
+
+#else
+#define COMPUTE                                                   \
+  /* fill with bias */                                            \
+  "vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ /* load weights */ \
+  "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */      \
+  "vld1.32 {d0-d3}, [%[r0]]!\n"    /* load input r0, 0,1*/        \
+  "vand.i32 q12, q8, q8\n"                                        \
+  "vld1.32 {d4-d7}, [%[r0]]!\n"    /* load input r0, 2,3*/        \
+  "vand.i32 q13, q8, q8\n"                                        \
+  "vld1.32 {d8-d11}, [%[r0]]!\n"   /* load input r0, 4,5*/        \
+  "vand.i32 q14, q8, q8\n"                                        \
+  "vld1.32 {d12-d15}, [%[r0]]!\n"  /* load input r0, 6,7*/        \
+  "vand.i32 q15, q8, q8\n"                                        \
+  "vld1.32 {d16-d17}, [%[r0]]\n"   /* load input r0, 8*/          \
+  "vmla.f32 q12, q9, q0 @ w0 * inr0\n"                            \
+  "vmla.f32 q13, q9, q2 @ w0 * inr2\n"                            \
+  "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */          \
+  "vmla.f32 q14, q9, q4 @ w0 * inr4\n"                            \
+  "vmla.f32 q15, q9, q6 @ w0 * inr6\n"                            \
+  "vmla.f32 q12, q10, q1 @ w1 * inr1\n"                           \
+  "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n"                   \
+  "vmla.f32 q13, q10, q3 @ w1 * inr3\n"                           \
+  "vmla.f32 q14, q10, q5 @ w1 * inr5\n"                           \
+  "vmla.f32 q15, q10, q7 @ w1 * inr7\n"                           \
+  "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */      \
+  "vmla.f32 q12, q11, q2 @ w2 * inr2\n"                           \
+  "vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n"                   \
+  "vmla.f32 q13, q11, q4 @ w2 * inr4\n"                           \
+  "vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n"                  \
+  "vmla.f32 q14, q11, q6 @ w2 * inr6\n"                           \
+  "vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n"                 \
+  "vmla.f32 q15, q11, q8 @ w2 * inr8\n" /* mul r1 with w3, w4*/   \
+  "vmla.f32 q12, q9, q0 @ w3 * inr0\n"                            \
+  "vmla.f32 q13, q9, q2 @ w3 * inr2\n"                            \
+  "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */          \
+  "vmla.f32 q14, q9, q4 @ w3 * inr4\n"                            \
+  "vmla.f32 q15, q9, q6 @ w3 * inr6\n"                            \
+  "vld1.32 {d16-d17}, [%[r1]]\n"   /* load input r1, 8*/          \
+  "vmla.f32 q12, q10, q1 @ w4 * inr1\n"                           \
+  "vld1.32 {d0-d3}, [%[r2]]!
@ load r2, 0, 1\n" \ + "vmla.f32 q13, q10, q3 @ w4 * inr3\n" \ + "vmla.f32 q14, q10, q5 @ w4 * inr5\n" \ + "vmla.f32 q15, q10, q7 @ w4 * inr7\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ \ + "vmla.f32 q12, q11, q2 @ w5 * inr2\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \ + "vmla.f32 q13, q11, q4 @ w5 * inr4\n" \ + "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \ + "vmla.f32 q14, q11, q6 @ w5 * inr6\n" \ + "vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n" \ + "vmla.f32 q15, q11, q8 @ w5 * inr8\n" /* mul r2 with w6, w7*/ \ + "vmla.f32 q12, q9, q0 @ w6 * inr0\n" \ + "vmla.f32 q13, q9, q2 @ w6 * inr2\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */ \ + "vmla.f32 q14, q9, q4 @ w6 * inr4\n" \ + "vmla.f32 q15, q9, q6 @ w6 * inr6\n" \ + "vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/ \ + "vmla.f32 q12, q10, q1 @ w7 * inr1\n" \ + "vmla.f32 q13, q10, q3 @ w7 * inr3\n" \ + "vmla.f32 q14, q10, q5 @ w7 * inr5\n" \ + "vmla.f32 q15, q10, q7 @ w7 * inr7\n" \ + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \ + "vmla.f32 q12, q11, q2 @ w8 * inr2\n" \ + "vmla.f32 q13, q11, q4 @ w8 * inr4\n" \ + "vmla.f32 q14, q11, q6 @ w8 * inr6\n" \ + "vmla.f32 q15, q11, q8 @ w8 * inr8\n" /* transpose */ \ + "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \ + "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \ + "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \ + "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ +#define RELU /* relu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[six_ptr]]\n" \ + "vmax.f32 q12, q12, q0\n" \ + "vmax.f32 q13, q13, q0\n" \ + "vmax.f32 q14, q14, q0\n" \ + "vmax.f32 q15, q15, q0\n" +#define RELU6 /* relu6 */ \ + "vmin.f32 q12, q12, q1\n" \ + "vmin.f32 q13, q13, q1\n" \ + "vmin.f32 q14, q14, q1\n" \ + "vmin.f32 q15, q15, q1\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[scale_ptr]]\n" \ + "vcge.f32 q2, q12, q0 @ q0 > 0 \n" \ + "vcge.f32 q4, q13, q0 @ q0 > 0 \n" \ + "vcge.f32 q6, q14, q0 @ q0 > 0 \n" \ + "vcge.f32 q8, q15, q0 @ q0 > 0 \n" \ + "vmul.f32 q3, q12, q1 @ mul \n" \ + "vmul.f32 q5, q13, q1 @ mul \n" \ + "vmul.f32 q7, q14, q1 @ mul \n" \ + "vmul.f32 q9, q15, q1 @ mul \n" \ + "vbif q12, q3, q2 @ choose \n" \ + "vbif q13, q5, q4 @ choose \n" \ + "vbif q14, q7, q6 @ choose \n" \ + "vbif q15, q9, q8 @ choose \n" +#define STORE /* save result */ \ + "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \ + "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \ + "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \ + "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ + +#endif + +void act_switch_3x3s2(const float* inr0, + const float* inr1, + const float* inr2, + float* outc0, + float* outc1, + float* outc2, + float* outc3, + const float* weight_c, + float* bias_local, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + float32x4_t w7, + float32x4_t w8, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(tmp); + float32x4_t vscale = vdupq_n_f32(ss); +#else + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), 
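+                     /* Operand-constraint cheat sheet for the asm blocks
+                        in this file: "+r" is an in/out general-purpose
+                        register (the row and output pointers are advanced
+                        by the post-indexed loads/stores inside the asm);
+                        "w" is a NEON register operand (the preloaded
+                        weight vectors w0..w8, and vsix/vscale for
+                        relu6/leaky-relu, stay resident in v-registers);
+                        the clobber list names every v/q register the body
+                        touches plus "cc"/"memory" so the compiler neither
+                        caches values across the asm nor reuses those
+                        registers. */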
+ [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local), + [vsix] "w"(vsix) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local), + [vscale] "w"(vscale) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + 
[r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} + +void conv_3x3s2_depthwise_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + const operators::ActivationParam act_param, + ARMContext* ctx) { + auto paddings = *param.paddings; + int threads = ctx->threads(); + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + const int out_c_block = 4; + const int out_h_kernel = 1; + const int out_w_kernel = 4; + const int win_ext = ow * 2 + 1; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh * 2 + 1; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = threads * prein_size + win_round + ow_round; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + auto ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 9; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(weight_c + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(weight_c + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(weight_c + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(weight_c + 32); // w8, v31 +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc0 = dout_c00 + h * ow; + float* outc1 = outc0 + size_out_channel; + float* outc2 = outc1 + size_out_channel; + float* outc3 = outc2 + size_out_channel; + const float* inr0 = pre_din + h * 2 * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + 
+            case 3:
+              outc1 = ptr_write;
+            case 2:
+              outc2 = ptr_write;
+            case 1:
+              outc3 = ptr_write;
+            default:
+              break;
+          }
+        }
+        auto c0 = outc0;
+        auto c1 = outc1;
+        auto c2 = outc2;
+        auto c3 = outc3;
+        float pre_out[16];
+        for (int w = 0; w < w_loop; ++w) {
+          bool flag_mask = (w == w_loop - 1) && flag_remain;
+          if (flag_mask) {
+            c0 = outc0;
+            c1 = outc1;
+            c2 = outc2;
+            c3 = outc3;
+            outc0 = pre_out;
+            outc1 = pre_out + 4;
+            outc2 = pre_out + 8;
+            outc3 = pre_out + 12;
+          }
+#ifdef __aarch64__
+          act_switch_3x3s2(inr0,
+                           inr1,
+                           inr2,
+                           outc0,
+                           outc1,
+                           outc2,
+                           outc3,
+                           weight_c,
+                           bias_local,
+                           w0,
+                           w1,
+                           w2,
+                           w3,
+                           w4,
+                           w5,
+                           w6,
+                           w7,
+                           w8,
+                           act_param);
+#else
+          act_switch_3x3s2(inr0,
+                           inr1,
+                           inr2,
+                           outc0,
+                           outc1,
+                           outc2,
+                           outc3,
+                           weight_c,
+                           bias_local,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           act_param);
+#endif
+          if (flag_mask) {
+            for (int i = 0; i < remain; ++i) {
+              c0[i] = pre_out[i];
+              c1[i] = pre_out[i + 4];
+              c2[i] = pre_out[i + 8];
+              c3[i] = pre_out[i + 12];
+            }
+          }
+          inr0 += 32;
+          inr1 += 32;
+          inr2 += 32;
+          outc0 += 4;
+          outc1 += 4;
+          outc2 += 4;
+          outc3 += 4;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
index 1a2e42e0a9ca4193be84a21247112de8cdc144a1..6125547b8ba611d016d5d85359a4138b0ede7607 100644
--- a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc
@@ -13,9602 +13,750 @@
 // limitations under the License.
 
 #include <arm_neon.h>
+#include "lite/backends/arm/math/conv_block_utils.h"
 #include "lite/backends/arm/math/conv_depthwise.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
 
 namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
-//! weights layout
-//! *-----------------------*-----*
-//! w0 <-- | W0  W1  W2  W3  | W4  |
-//! *-----------------------*     |
-//! w1 <-- | W5  W6  W7  W8  | W9  |
-//! *-----------------------*     | --> w5
-//! w2 <-- | W10 W11 W12 W13 | W14 |
-//! *-----------------------*     |
-//! w3 <-- | W15 W16 W17 W18 | W19 |
-//! *-----------------------*-----*
-//! w4 <-- | W20 W21 W22 W23 | W24 | --> w6[0]
-//!
*-----------------------*-----* - -void conv_depthwise_5x5s1_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1_small_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1_small_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -static float* prepad_input( - const float* input, int num, int ch_in, int h_in, int w_in, int pad) { - int h_new = h_in + 2 * pad; - int w_new = w_in + 2 * pad; - float* new_input = - static_cast(malloc(h_new * w_new * ch_in * num * sizeof(float))); - float* new_input_ptr = new_input; - for (int c = 0; c < num * ch_in; ++c) { - memset(new_input_ptr, 0x00, w_new * pad * sizeof(float)); - new_input_ptr += w_new * pad; - for (int i = 0; i < h_in; ++i) { - memset(new_input_ptr, 0x00, pad * sizeof(float)); - new_input_ptr += pad; - memcpy(new_input_ptr, input, w_in * sizeof(float)); - new_input_ptr += w_in; - input += w_in; - memset(new_input_ptr, 0x00, pad * sizeof(float)); - new_input_ptr += pad; - } - memset(new_input_ptr, 0x00, w_new * pad * sizeof(float)); - new_input_ptr += w_new * pad; - } - return new_input; -} - -#ifdef __aarch64__ - -//! kernel for one out without extracting data mid -//! deal with four lines out -void compute_one_out_without_extract(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - float32x4_t w5, - float32x4_t w6, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! din0 - din7: 5 v20, v21 - //! 
dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // ext - "ext v22.16b, v20.16b, v21.16b, #4 \n" // 1 2 3 4 - "ext v23.16b, v20.16b, v21.16b, #8 \n" // 2 3 4 5 - "ext v24.16b, v20.16b, v21.16b, #12 \n" // 3 4 5 6 - - // in col5 - "fmla v16.4s, %[w5].4s, v20.4s \n" - "fmla v17.4s, %[w5].4s, v22.4s \n" - "fmla v18.4s, %[w5].4s, v23.4s \n" - "fmla v19.4s, %[w5].4s, v24.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // in[24] * w6[0] - "fmla v25.4s, v21.4s, %[w6].s[0]\n" - "fadd v25.4s, v25.4s, v31.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v31"); -} - -//! kernel for one out without extracting data mid -//! deal with four lines out -void compute_one_out_without_extract_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - float32x4_t w5, - float32x4_t w6, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! din0 - din7: 5 v20, v21 - //! 
dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // ext - "ext v22.16b, v20.16b, v21.16b, #4 \n" // 1 2 3 4 - "ext v23.16b, v20.16b, v21.16b, #8 \n" // 2 3 4 5 - "ext v24.16b, v20.16b, v21.16b, #12 \n" // 3 4 5 6 - - // in col5 - "fmla v16.4s, %[w5].4s, v20.4s \n" - "fmla v17.4s, %[w5].4s, v22.4s \n" - "fmla v18.4s, %[w5].4s, v23.4s \n" - "fmla v19.4s, %[w5].4s, v24.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "movi v30.4s, #0 \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // in[24] * w6[0] - "fmla v25.4s, v21.4s, %[w6].s[0] \n" - "fadd v25.4s, v25.4s, v31.4s \n" - "fmax v25.4s, v25.4s, v30.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v30", - "v31"); -} - -//! kernel for one out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_one_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! 
weights: v0-v4 - asm volatile( - // load weights - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" - "ldr q1, [%[wh]], #20 \n" - "ldr q2, [%[wh]], #20 \n" - "ldr q3, [%[wh]], #20 \n" - "ldr q4, [%[wh]], #20 \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v31"); -} - -//! kernel for one out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_one_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! 
weights: v0-v4 - asm volatile( - // load weights - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" - "ldr q1, [%[wh]], #20 \n" - "ldr q2, [%[wh]], #20 \n" - "ldr q3, [%[wh]], #20 \n" - "ldr q4, [%[wh]], #20 \n" - - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "movi v30.4s, #0 \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - "fmax v25.4s, v25.4s, v30.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v30", - "v31"); -} - -//! kernel for one out with extracting data post -//! deal with four lines out -void compute_one_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! 
dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v31.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v31"); -} - -//! kernel for one out with extracting data post -//! deal with four lines out -void compute_one_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! 
dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "movi v30.4s, #0 \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - "fmax v25.4s, v25.4s, v30.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v30", - "v31"); -} - -//! kernel for two out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_two_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! 
weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - 
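// How the deleted compute_two_out_extract_pre handled the two left-padded
// outputs per row quad: a faddp cascade (faddp v5, v16, v17 / faddp v6,
// v18, v19 / faddp v5, v5, v6) collapses the four per-lane partial sums of
// each accumulator into one scalar per output line; the 5-tap weight rows
// are then shifted one tap with ext so the same loaded input rows serve
// the neighbouring column, and zip1/zip2 interleave the two result vectors
// before the paired "str d" stores. The replacement file drops these
// edge-case kernels in favour of the conv_block_utils-style padded
// pre-pack used elsewhere in this patch.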
"v31"); -} - -//! kernel for two out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_two_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str 
d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for two out with extracting data post -//! deal with four lines out -void compute_two_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - asm volatile( - "movi v31.4s, #0 \n" - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, 
v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for two out with extracting data post -//! deal with four lines out -void compute_two_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - asm volatile( - "movi v31.4s, #0 \n" - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, 
v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for three out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_three_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! 
weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, 
v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "st1 {v25.s}[0], [%[dout0]], #4 \n" - "st1 {v25.s}[1], [%[dout1]], #4 \n" - "st1 {v25.s}[2], [%[dout2]], #4 \n" - "st1 {v25.s}[3], [%[dout3]], #4 \n" - - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for three out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_three_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! 
weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, 
v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - "fmax v25.4s, v25.4s, v31.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "st1 {v25.s}[0], [%[dout0]], #4 \n" - "st1 {v25.s}[1], [%[dout1]], #4 \n" - "st1 {v25.s}[2], [%[dout2]], #4 \n" - "st1 {v25.s}[3], [%[dout3]], #4 \n" - - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for three out with extracting data post -//! deal with four lines out -void compute_three_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! 
dout0 - dout3: v6, v8, v25 - asm volatile( - "movi v31.4s, #0 \n" - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - 
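      // Accumulation pattern: output row r is accumulated in v(16 + r),
      // gathering weight row k against input row r + k for k = 0..4, which is
      // why the input rows v8-v15 slide by one register per weight row across
      // the four accumulators v16-v19.  The "post" (right border) variants
      // shift the *inputs* left through the zeros in v31, so each pass sees
      // one fewer valid input column.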
"fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]], #8 \n" - "str d7, [%[dout1]], #8 \n" - "str d8, [%[dout2]], #8 \n" - "str d9, [%[dout3]], #8 \n" - - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for three out with extracting data post -//! deal with four lines out -void compute_three_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! 
dout0 - dout3: v6, v8, v25 - asm volatile( - "movi v31.4s, #0 \n" - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - 
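      // Epilogue: the faddp chains below collapse the four per-row
      // accumulators into a single vector whose lane r holds the full
      // horizontal sum for output row r; zip1/zip2 then interleave the
      // per-pass results into column order, v20 adds the bias, and fmax
      // against v31 (all zeros) applies the fused relu before the stores.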
"fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - "fmax v25.4s, v25.4s, v31.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]], #8 \n" - "str d7, [%[dout1]], #8 \n" - "str d8, [%[dout2]], #8 \n" - "str d9, [%[dout3]], #8 \n" - - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for four out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_four_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v0-v3 - //! 
weights: v0-v4, v5, v6 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "mov x0, #20 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]] \n" // 21, 22, 23, 24 - "sub %[wh], %[wh], #68 \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]] \n" - "ld1 {v9.4s}, [%[din1]] \n" - "ld1 {v10.4s}, [%[din2]] \n" - "ld1 {v11.4s}, [%[din3]] \n" - "ld1 {v12.4s}, [%[din4]] \n" - "ld1 {v13.4s}, [%[din5]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]] \n" - "ld1 {v15.4s}, [%[din7]] \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load weights col5 - "ld1 {v5.s}[0], [%[wh]], x0 \n" - "ld1 {v5.s}[1], [%[wh]], x0 \n" - "ld1 {v5.s}[2], [%[wh]], x0 \n" - "ld1 {v5.s}[3], [%[wh]], x0 \n" - "ld1 {v6.s}[0], [%[wh]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // load in col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - 
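      // Left-border "pre" technique: the weights, not the inputs, are
      // shifted.  The q-loads above start at wh + 4 with a 20-byte
      // post-increment, so v0-v4 hold columns 1-4 of each 5-float weight row,
      // and every "ext vX, vX, v31, #4" pass drops one more leading column --
      // this third pass multiplies with columns 3-4 only.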
"fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v22.16b, v20.16b, v21.16b, #4 \n" - "ext v23.16b, v20.16b, v21.16b, #8 \n" - "ext v24.16b, v20.16b, v21.16b, #12 \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v22.4s \n" - "fmul v18.4s, v5.4s, v23.4s \n" - "fmul v19.4s, v5.4s, v24.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, v6.s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v28.4s, v26.4s \n" - "zip2 v2.4s, v28.4s, v26.4s \n" - "zip1 v4.4s, v27.4s, v25.4s \n" - "zip2 v6.4s, v27.4s, v25.4s \n" - - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [%[dout0]], #8 \n" - "str d1, [%[dout1]], #8 \n" - "str d2, [%[dout2]], #8 \n" - "str d3, [%[dout3]], #8 \n" - - "str d4, [%[dout0]] \n" - "str d5, [%[dout1]] \n" - "str d6, [%[dout2]] \n" - "str d7, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "x0", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -//! kernel for four out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_four_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v0-v3 - //! 
weights: v0-v4, v5, v6 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "mov x0, #20 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]] \n" // 21, 22, 23, 24 - "sub %[wh], %[wh], #68 \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]] \n" - "ld1 {v9.4s}, [%[din1]] \n" - "ld1 {v10.4s}, [%[din2]] \n" - "ld1 {v11.4s}, [%[din3]] \n" - "ld1 {v12.4s}, [%[din4]] \n" - "ld1 {v13.4s}, [%[din5]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]] \n" - "ld1 {v15.4s}, [%[din7]] \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load weights col5 - "ld1 {v5.s}[0], [%[wh]], x0 \n" - "ld1 {v5.s}[1], [%[wh]], x0 \n" - "ld1 {v5.s}[2], [%[wh]], x0 \n" - "ld1 {v5.s}[3], [%[wh]], x0 \n" - "ld1 {v6.s}[0], [%[wh]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // load in col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - 
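      // The fifth weight column cannot be reached by ext-shifting v0-v4, so
      // it is gathered lane-by-lane above: "ld1 {v5.s}[i], [%[wh]], x0" with
      // x0 = 20 (one 5-float weight row, in bytes) collects w[4], w[9],
      // w[14] and w[19] into v5 and w[24] into v6.s[0]; they are applied
      // later against the input column held in v20/v21.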
"fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v22.16b, v20.16b, v21.16b, #4 \n" - "ext v23.16b, v20.16b, v21.16b, #8 \n" - "ext v24.16b, v20.16b, v21.16b, #12 \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v22.4s \n" - "fmul v18.4s, v5.4s, v23.4s \n" - "fmul v19.4s, v5.4s, v24.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, v6.s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v28.4s, v26.4s \n" - "zip2 v2.4s, v28.4s, v26.4s \n" - "zip1 v4.4s, v27.4s, v25.4s \n" - "zip2 v6.4s, v27.4s, v25.4s \n" - - // add bias - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - // relu - "fmax v0.4s, v0.4s, v31.4s \n" - "fmax v2.4s, v2.4s, v31.4s \n" - "fmax v4.4s, v4.4s, v31.4s \n" - "fmax v6.4s, v6.4s, v31.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [%[dout0]], #8 \n" - "str d1, [%[dout1]], #8 \n" - "str d2, [%[dout2]], #8 \n" - "str d3, [%[dout3]], #8 \n" - - "str d4, [%[dout0]] \n" - "str d5, [%[dout1]] \n" - "str d6, [%[dout2]] \n" - "str d7, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "x0", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -//! kernel for four out with extracting data post -//! deal with four lines out -void compute_four_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! 
dout0 - dout3: v0-v3 - const int64_t s_12 = 12; - const float* doutl[4] = {dout0, dout1, dout2, dout3}; - void* doutl_ptr = reinterpret_cast(doutl); - asm volatile( - "movi v31.4s, #0 \n" - "ldp x0, x1, [%[doutl]], #16 \n" - "ldp x2, x3, [%[doutl]] \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]], %[s_12] \n" - "ld1 {v9.4s}, [%[din1]], %[s_12] \n" - "ld1 {v10.4s}, [%[din2]], %[s_12] \n" - "ld1 {v11.4s}, [%[din3]], %[s_12] \n" - "ld1 {v12.4s}, [%[din4]], %[s_12] \n" - "ld1 {v13.4s}, [%[din5]], %[s_12] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], %[s_12] \n" - "ld1 {v15.4s}, [%[din7]], %[s_12] \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load input col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // load input col5 - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, 
%[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v8.16b, v20.16b, v21.16b, #4 \n" - "ext v9.16b, v20.16b, v21.16b, #8 \n" - "ext v10.16b, v20.16b, v21.16b, #12 \n" - - // ext weights col0 - "ins v5.s[0], %[w0].s[0] \n" - "ins v5.s[1], %[w1].s[0] \n" - "ins v5.s[2], %[w2].s[0] \n" - "ins v5.s[3], %[w3].s[0] \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v8.4s \n" - "fmul v18.4s, v5.4s, v9.4s \n" - "fmul v19.4s, v5.4s, v10.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, %[w4].s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v25.4s, v27.4s \n" - "zip2 v2.4s, v25.4s, v27.4s \n" - "zip1 v4.4s, v26.4s, v28.4s \n" - "zip2 v6.4s, v26.4s, v28.4s \n" - - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [x0], #8 \n" - "str d1, [x1], #8 \n" - "str d2, [x2], #8 \n" - "str d3, [x3], #8 \n" - - "str d4, [x0] \n" - "str d5, [x1] \n" - "str d6, [x2] \n" - "str d7, [x3] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [doutl] "+r"(doutl_ptr) - : [s_12] "r"(s_12), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "x0", - "x1", - "x2", - "x3", - "v0", - "v1", - "v2", - "v3", - "v5", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -//! kernel for four out with extracting data post -//! deal with four lines out -void compute_four_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! 
dout0 - dout3: v0-v3 - const int64_t s_12 = 12; - const float* doutl[4] = {dout0, dout1, dout2, dout3}; - void* doutl_ptr = reinterpret_cast(doutl); - asm volatile( - "movi v31.4s, #0 \n" - "ldp x0, x1, [%[doutl]], #16 \n" - "ldp x2, x3, [%[doutl]] \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]], %[s_12] \n" - "ld1 {v9.4s}, [%[din1]], %[s_12] \n" - "ld1 {v10.4s}, [%[din2]], %[s_12] \n" - "ld1 {v11.4s}, [%[din3]], %[s_12] \n" - "ld1 {v12.4s}, [%[din4]], %[s_12] \n" - "ld1 {v13.4s}, [%[din5]], %[s_12] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], %[s_12] \n" - "ld1 {v15.4s}, [%[din7]], %[s_12] \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load input col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // load input col5 - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, 
%[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v8.16b, v20.16b, v21.16b, #4 \n" - "ext v9.16b, v20.16b, v21.16b, #8 \n" - "ext v10.16b, v20.16b, v21.16b, #12 \n" - - // ext weights col0 - "ins v5.s[0], %[w0].s[0] \n" - "ins v5.s[1], %[w1].s[0] \n" - "ins v5.s[2], %[w2].s[0] \n" - "ins v5.s[3], %[w3].s[0] \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v8.4s \n" - "fmul v18.4s, v5.4s, v9.4s \n" - "fmul v19.4s, v5.4s, v10.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, %[w4].s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v25.4s, v27.4s \n" - "zip2 v2.4s, v25.4s, v27.4s \n" - "zip1 v4.4s, v26.4s, v28.4s \n" - "zip2 v6.4s, v26.4s, v28.4s \n" - - // add bias - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - // relu - "fmax v0.4s, v0.4s, v31.4s \n" - "fmax v2.4s, v2.4s, v31.4s \n" - "fmax v4.4s, v4.4s, v31.4s \n" - "fmax v6.4s, v6.4s, v31.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [x0], #8 \n" - "str d1, [x1], #8 \n" - "str d2, [x2], #8 \n" - "str d3, [x3], #8 \n" - - "str d4, [x0] \n" - "str d5, [x1] \n" - "str d6, [x2] \n" - "str d7, [x3] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [doutl] "+r"(doutl_ptr) - : [s_12] "r"(s_12), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "x0", - "x1", - "x2", - "x3", - "v0", - "v1", - "v2", - "v3", - "v5", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -void conv_depthwise_5x5s1_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 
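      // Padding split: pad_new = min(pad, 4) is the part a 5x5 window can
      // actually overlap, handled by the extract kernels above; pad_0 =
      // pad - pad_new yields output rows/columns that never touch real
      // input, so they are filled directly with the bias (or zeros) below
      // instead of being convolved.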
4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[8]; - const float* dinl[8]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 8; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 8 - pad_new > h_in) { - switch (h + 8 - pad_new - h_in) { - case 7: - din_list[1] = zero_ptr; - case 6: - din_list[2] = zero_ptr; - case 5: - din_list[3] = zero_ptr; - case 4: - din_list[4] = zero_ptr; - case 3: - din_list[5] = zero_ptr; - case 2: - din_list[6] = zero_ptr; - case 1: - din_list[7] = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - - //! every h loop, deal with 8 line input - dinl[0] = din_list[0]; - dinl[1] = din_list[1]; - dinl[2] = din_list[2]; - dinl[3] = din_list[3]; - dinl[4] = din_list[4]; - dinl[5] = din_list[5]; - dinl[6] = din_list[6]; - dinl[7] = din_list[7]; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - if (flag_bias) { - //! 
deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - *dout_ptr2++ = bias_c; - *dout_ptr3++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - //! mid loop - if (mid_cnt > 0) { - void* dinl_ptr = reinterpret_cast(dinl); - int mid_loop = mid_cnt; - asm volatile( - //! din: v7-v14 - //! dout: v15-v18 - "mov x0, #0 \n" - "mov x1, #4 \n" - "ldp x2, x3, [%[dinl]], #16 \n" - "ldp x4, x5, [%[dinl]], #16 \n" - "ldp x6, x7, [%[dinl]], #16 \n" - "ldp x8, x9, [%[dinl]], #16 \n" - - "ld1 {v7.4s} , [x2], x1 \n" - "ld1 {v8.4s} , [x3], x1 \n" - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - //! load bias - "ld1 {v19.4s}, [%[bias]] \n" - - "1: \n" - //! add bias to output - "mov v15.16b, v19.16b \n" - "mov v16.16b, v19.16b \n" - "mov v17.16b, v19.16b \n" - "mov v18.16b, v19.16b \n" - - //! 
loop cnt is even, prefetch 64 Byte to l1 cache - "cmp x0, #1 \n" - "bne 2f \n" - "mov x0, #0 \n" - "prfm pldl1keep, [x2] \n" - "prfm pldl1keep, [x3] \n" - "prfm pldl1keep, [x4] \n" - "prfm pldl1keep, [x5] \n" - "prfm pldl1keep, [x6] \n" - "prfm pldl1keep, [x7] \n" - "prfm pldl1keep, [x8] \n" - "prfm pldl1keep, [x9] \n" - - "2: \n" - // weights col 0 - "fmla v15.4s, v7.4s , %[w0].s[0] \n" - "fmla v16.4s, v8.4s , %[w0].s[0] \n" - "fmla v17.4s, v9.4s , %[w0].s[0] \n" - "fmla v18.4s, v10.4s, %[w0].s[0] \n" - - "fmla v15.4s, v8.4s , %[w1].s[0] \n" - "fmla v16.4s, v9.4s , %[w1].s[0] \n" - "fmla v17.4s, v10.4s, %[w1].s[0] \n" - "fmla v18.4s, v11.4s, %[w1].s[0] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[0] \n" - "fmla v16.4s, v10.4s, %[w2].s[0] \n" - "fmla v17.4s, v11.4s, %[w2].s[0] \n" - "fmla v18.4s, v12.4s, %[w2].s[0] \n" - - "fmla v15.4s, v10.4s, %[w3].s[0] \n" - "fmla v16.4s, v11.4s, %[w3].s[0] \n" - "fmla v17.4s, v12.4s, %[w3].s[0] \n" - "fmla v18.4s, v13.4s, %[w3].s[0] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[0] \n" - "fmla v16.4s, v12.4s, %[w4].s[0] \n" - "fmla v17.4s, v13.4s, %[w4].s[0] \n" - "fmla v18.4s, v14.4s, %[w4].s[0] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 1 - "fmla v15.4s, v7.4s , %[w0].s[1] \n" - "fmla v16.4s, v8.4s , %[w0].s[1] \n" - "fmla v17.4s, v9.4s , %[w0].s[1] \n" - "fmla v18.4s, v10.4s, %[w0].s[1] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[1] \n" - "fmla v16.4s, v9.4s , %[w1].s[1] \n" - "fmla v17.4s, v10.4s, %[w1].s[1] \n" - "fmla v18.4s, v11.4s, %[w1].s[1] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[1] \n" - "fmla v16.4s, v10.4s, %[w2].s[1] \n" - "fmla v17.4s, v11.4s, %[w2].s[1] \n" - "fmla v18.4s, v12.4s, %[w2].s[1] \n" - - "fmla v15.4s, v10.4s, %[w3].s[1] \n" - "fmla v16.4s, v11.4s, %[w3].s[1] \n" - "fmla v17.4s, v12.4s, %[w3].s[1] \n" - "fmla v18.4s, v13.4s, %[w3].s[1] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[1] \n" - "fmla v16.4s, v12.4s, %[w4].s[1] \n" - "fmla v17.4s, v13.4s, %[w4].s[1] \n" - "fmla v18.4s, v14.4s, %[w4].s[1] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 2 - "fmla v15.4s, v7.4s , %[w0].s[2] \n" - "fmla v16.4s, v8.4s , %[w0].s[2] \n" - "fmla v17.4s, v9.4s , %[w0].s[2] \n" - "fmla v18.4s, v10.4s, %[w0].s[2] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[2] \n" - "fmla v16.4s, v9.4s , %[w1].s[2] \n" - "fmla v17.4s, v10.4s, %[w1].s[2] \n" - "fmla v18.4s, v11.4s, %[w1].s[2] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[2] \n" - "fmla v16.4s, v10.4s, %[w2].s[2] \n" - "fmla v17.4s, v11.4s, %[w2].s[2] \n" - "fmla v18.4s, v12.4s, %[w2].s[2] \n" - - "fmla v15.4s, v10.4s, %[w3].s[2] \n" - "fmla v16.4s, v11.4s, %[w3].s[2] \n" - "fmla v17.4s, v12.4s, %[w3].s[2] \n" - "fmla v18.4s, v13.4s, %[w3].s[2] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[2] \n" - "fmla v16.4s, v12.4s, %[w4].s[2] \n" - "fmla v17.4s, v13.4s, %[w4].s[2] \n" - "fmla v18.4s, v14.4s, %[w4].s[2] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 3 - "fmla v15.4s, v7.4s , %[w0].s[3] \n" - "fmla v16.4s, v8.4s , %[w0].s[3] \n" - "fmla v17.4s, v9.4s 
, %[w0].s[3] \n" - "fmla v18.4s, v10.4s, %[w0].s[3] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[3] \n" - "fmla v16.4s, v9.4s , %[w1].s[3] \n" - "fmla v17.4s, v10.4s, %[w1].s[3] \n" - "fmla v18.4s, v11.4s, %[w1].s[3] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[3] \n" - "fmla v16.4s, v10.4s, %[w2].s[3] \n" - "fmla v17.4s, v11.4s, %[w2].s[3] \n" - "fmla v18.4s, v12.4s, %[w2].s[3] \n" - - "fmla v15.4s, v10.4s, %[w3].s[3] \n" - "fmla v16.4s, v11.4s, %[w3].s[3] \n" - "fmla v17.4s, v12.4s, %[w3].s[3] \n" - "fmla v18.4s, v13.4s, %[w3].s[3] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[3] \n" - "fmla v16.4s, v12.4s, %[w4].s[3] \n" - "fmla v17.4s, v13.4s, %[w4].s[3] \n" - "fmla v18.4s, v14.4s, %[w4].s[3] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 4 - "fmla v15.4s, v7.4s, %[w5].s[0] \n" - "fmla v16.4s, v8.4s, %[w5].s[0] \n" - "fmla v17.4s, v9.4s, %[w5].s[0] \n" - "fmla v18.4s, v10.4s, %[w5].s[0] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s, %[w5].s[1] \n" - "fmla v16.4s, v9.4s, %[w5].s[1] \n" - "fmla v17.4s, v10.4s, %[w5].s[1] \n" - "fmla v18.4s, v11.4s, %[w5].s[1] \n" - - "fmla v15.4s, v9.4s , %[w5].s[2] \n" - "fmla v16.4s, v10.4s, %[w5].s[2] \n" - "fmla v17.4s, v11.4s, %[w5].s[2] \n" - "fmla v18.4s, v12.4s, %[w5].s[2] \n" - - "fmla v15.4s, v10.4s, %[w5].s[3] \n" - "fmla v16.4s, v11.4s, %[w5].s[3] \n" - "fmla v17.4s, v12.4s, %[w5].s[3] \n" - "fmla v18.4s, v13.4s, %[w5].s[3] \n" - - "fmla v15.4s, v11.4s, %[w6].s[0] \n" - "fmla v16.4s, v12.4s, %[w6].s[0] \n" - "fmla v17.4s, v13.4s, %[w6].s[0] \n" - "fmla v18.4s, v14.4s, %[w6].s[0] \n" - - "st1 {v15.4s}, [%[dout0]], #16 \n" - "st1 {v16.4s}, [%[dout1]], #16 \n" - "st1 {v17.4s}, [%[dout2]], #16 \n" - "st1 {v18.4s}, [%[dout3]], #16 \n" - - "subs %w[cnt], %w[cnt], #1 \n" - "add x0, x0, #1 \n" - "bne 1b \n" - - : [dout0] "+r"(dout_ptr0), - [dout1] "+r"(dout_ptr1), - [dout2] "+r"(dout_ptr2), - [dout3] "+r"(dout_ptr3), - [cnt] "+r"(mid_loop), - [dinl] "+r"(dinl_ptr) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(vbias) - : "cc", - "memory", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7", - "x8", - "x9", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19"); - } - dinl[0] += 4 * mid_cnt; - dinl[1] += 4 * mid_cnt; - dinl[2] += 4 * mid_cnt; - dinl[3] += 4 * mid_cnt; - dinl[4] += 4 * mid_cnt; - dinl[5] += 4 * mid_cnt; - dinl[6] += 4 * mid_cnt; - dinl[7] += 4 * mid_cnt; - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - dinl[0]++; - dinl[1]++; - dinl[2]++; - dinl[3]++; - dinl[4]++; - dinl[5]++; - dinl[6]++; - dinl[7]++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - //! 
deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din_list[0] = din_list[4]; - din_list[1] = din_list[5]; - din_list[2] = din_list[6]; - din_list[3] = din_list[7]; - din_list[4] = din_list[3] + w_in; - din_list[5] = din_list[4] + w_in; - din_list[6] = din_list[5] + w_in; - din_list[7] = din_list[6] + w_in; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 
4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[8]; - const float* dinl[8]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 8; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 8 - pad_new > h_in) { - switch (h + 8 - pad_new - h_in) { - case 7: - din_list[1] = zero_ptr; - case 6: - din_list[2] = zero_ptr; - case 5: - din_list[3] = zero_ptr; - case 4: - din_list[4] = zero_ptr; - case 3: - din_list[5] = zero_ptr; - case 2: - din_list[6] = zero_ptr; - case 1: - din_list[7] = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - - //! every h loop, deal with 8 line input - dinl[0] = din_list[0]; - dinl[1] = din_list[1]; - dinl[2] = din_list[2]; - dinl[3] = din_list[3]; - dinl[4] = din_list[4]; - dinl[5] = din_list[5]; - dinl[6] = din_list[6]; - dinl[7] = din_list[7]; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - if (flag_bias) { - //! 
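The weight-register layout prepared above packs all 25 taps into seven q-registers: w0..w4 hold taps 0..3 of each 5-tap row, w5 gathers the fifth tap of rows 0..3, and w6 lane 0 holds weights[24]. A sketch of the same packing with intrinsics; note this sketch zero-initializes w5 and w6 before the lane inserts, whereas the code above sets lanes of uninitialized vectors:

#include <arm_neon.h>

static void load_5x5_weights(const float* wc, float32x4_t w[7]) {
  for (int r = 0; r < 5; ++r) {
    w[r] = vld1q_f32(wc + 5 * r);        // taps r*5 .. r*5+3
  }
  float32x4_t w5 = vdupq_n_f32(0.f);     // initialized before lane sets
  w5 = vsetq_lane_f32(wc[4], w5, 0);
  w5 = vsetq_lane_f32(wc[9], w5, 1);
  w5 = vsetq_lane_f32(wc[14], w5, 2);
  w5 = vsetq_lane_f32(wc[19], w5, 3);
  w[5] = w5;
  w[6] = vsetq_lane_f32(wc[24], vdupq_n_f32(0.f), 0);
}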
deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - *dout_ptr2++ = bias_relu; - *dout_ptr3++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - //! mid loop - if (mid_cnt > 0) { - void* dinl_ptr = reinterpret_cast(dinl); - int mid_loop = mid_cnt; - asm volatile( - //! din: v7-v14 - //! dout: v15-v18 - "mov x0, #0 \n" - "mov x1, #4 \n" - "movi v31.4s, #0 \n" - "ldp x2, x3, [%[dinl]], #16 \n" - "ldp x4, x5, [%[dinl]], #16 \n" - "ldp x6, x7, [%[dinl]], #16 \n" - "ldp x8, x9, [%[dinl]], #16 \n" - - "ld1 {v7.4s} , [x2], x1 \n" - "ld1 {v8.4s} , [x3], x1 \n" - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - //! load bias - "ld1 {v19.4s}, [%[bias]] \n" - - "1: \n" - //! add bias to output - "mov v15.16b, v19.16b \n" - "mov v16.16b, v19.16b \n" - "mov v17.16b, v19.16b \n" - "mov v18.16b, v19.16b \n" - - //! 
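For the padded fringe handled above, the output never touches input data, so it is pre-filled with the bias value, clamped to zero when ReLU is fused (the bias_relu scalar). A minimal intrinsics sketch of that prologue, assuming nothing beyond what the pad_cnt/pad_remain loops above already do:

#include <arm_neon.h>
#include <algorithm>

static void fill_pad_with_bias_relu(float* dout, int n, float bias_c) {
  float v = std::max(bias_c, 0.f);   // ReLU applied once, on the scalar
  float32x4_t vv = vdupq_n_f32(v);
  int i = 0;
  for (; i + 4 <= n; i += 4) {
    vst1q_f32(dout + i, vv);         // 4 floats per store (pad_cnt loop)
  }
  for (; i < n; ++i) {
    dout[i] = v;                     // scalar tail (pad_remain loop)
  }
}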
loop cnt is even, prefetch 64 Byte to l1 cache - "cmp x0, #1 \n" - "bne 2f \n" - "mov x0, #0 \n" - "prfm pldl1keep, [x2] \n" - "prfm pldl1keep, [x3] \n" - "prfm pldl1keep, [x4] \n" - "prfm pldl1keep, [x5] \n" - "prfm pldl1keep, [x6] \n" - "prfm pldl1keep, [x7] \n" - "prfm pldl1keep, [x8] \n" - "prfm pldl1keep, [x9] \n" - - "2: \n" - // weights col 0 - "fmla v15.4s, v7.4s , %[w0].s[0] \n" - "fmla v16.4s, v8.4s , %[w0].s[0] \n" - "fmla v17.4s, v9.4s , %[w0].s[0] \n" - "fmla v18.4s, v10.4s, %[w0].s[0] \n" - - "fmla v15.4s, v8.4s , %[w1].s[0] \n" - "fmla v16.4s, v9.4s , %[w1].s[0] \n" - "fmla v17.4s, v10.4s, %[w1].s[0] \n" - "fmla v18.4s, v11.4s, %[w1].s[0] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[0] \n" - "fmla v16.4s, v10.4s, %[w2].s[0] \n" - "fmla v17.4s, v11.4s, %[w2].s[0] \n" - "fmla v18.4s, v12.4s, %[w2].s[0] \n" - - "fmla v15.4s, v10.4s, %[w3].s[0] \n" - "fmla v16.4s, v11.4s, %[w3].s[0] \n" - "fmla v17.4s, v12.4s, %[w3].s[0] \n" - "fmla v18.4s, v13.4s, %[w3].s[0] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[0] \n" - "fmla v16.4s, v12.4s, %[w4].s[0] \n" - "fmla v17.4s, v13.4s, %[w4].s[0] \n" - "fmla v18.4s, v14.4s, %[w4].s[0] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 1 - "fmla v15.4s, v7.4s , %[w0].s[1] \n" - "fmla v16.4s, v8.4s , %[w0].s[1] \n" - "fmla v17.4s, v9.4s , %[w0].s[1] \n" - "fmla v18.4s, v10.4s, %[w0].s[1] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[1] \n" - "fmla v16.4s, v9.4s , %[w1].s[1] \n" - "fmla v17.4s, v10.4s, %[w1].s[1] \n" - "fmla v18.4s, v11.4s, %[w1].s[1] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[1] \n" - "fmla v16.4s, v10.4s, %[w2].s[1] \n" - "fmla v17.4s, v11.4s, %[w2].s[1] \n" - "fmla v18.4s, v12.4s, %[w2].s[1] \n" - - "fmla v15.4s, v10.4s, %[w3].s[1] \n" - "fmla v16.4s, v11.4s, %[w3].s[1] \n" - "fmla v17.4s, v12.4s, %[w3].s[1] \n" - "fmla v18.4s, v13.4s, %[w3].s[1] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[1] \n" - "fmla v16.4s, v12.4s, %[w4].s[1] \n" - "fmla v17.4s, v13.4s, %[w4].s[1] \n" - "fmla v18.4s, v14.4s, %[w4].s[1] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 2 - "fmla v15.4s, v7.4s , %[w0].s[2] \n" - "fmla v16.4s, v8.4s , %[w0].s[2] \n" - "fmla v17.4s, v9.4s , %[w0].s[2] \n" - "fmla v18.4s, v10.4s, %[w0].s[2] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[2] \n" - "fmla v16.4s, v9.4s , %[w1].s[2] \n" - "fmla v17.4s, v10.4s, %[w1].s[2] \n" - "fmla v18.4s, v11.4s, %[w1].s[2] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[2] \n" - "fmla v16.4s, v10.4s, %[w2].s[2] \n" - "fmla v17.4s, v11.4s, %[w2].s[2] \n" - "fmla v18.4s, v12.4s, %[w2].s[2] \n" - - "fmla v15.4s, v10.4s, %[w3].s[2] \n" - "fmla v16.4s, v11.4s, %[w3].s[2] \n" - "fmla v17.4s, v12.4s, %[w3].s[2] \n" - "fmla v18.4s, v13.4s, %[w3].s[2] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[2] \n" - "fmla v16.4s, v12.4s, %[w4].s[2] \n" - "fmla v17.4s, v13.4s, %[w4].s[2] \n" - "fmla v18.4s, v14.4s, %[w4].s[2] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 3 - "fmla v15.4s, v7.4s , %[w0].s[3] \n" - "fmla v16.4s, v8.4s , %[w0].s[3] \n" - "fmla v17.4s, v9.4s 
, %[w0].s[3] \n" - "fmla v18.4s, v10.4s, %[w0].s[3] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[3] \n" - "fmla v16.4s, v9.4s , %[w1].s[3] \n" - "fmla v17.4s, v10.4s, %[w1].s[3] \n" - "fmla v18.4s, v11.4s, %[w1].s[3] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[3] \n" - "fmla v16.4s, v10.4s, %[w2].s[3] \n" - "fmla v17.4s, v11.4s, %[w2].s[3] \n" - "fmla v18.4s, v12.4s, %[w2].s[3] \n" - - "fmla v15.4s, v10.4s, %[w3].s[3] \n" - "fmla v16.4s, v11.4s, %[w3].s[3] \n" - "fmla v17.4s, v12.4s, %[w3].s[3] \n" - "fmla v18.4s, v13.4s, %[w3].s[3] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[3] \n" - "fmla v16.4s, v12.4s, %[w4].s[3] \n" - "fmla v17.4s, v13.4s, %[w4].s[3] \n" - "fmla v18.4s, v14.4s, %[w4].s[3] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 4 - "fmla v15.4s, v7.4s, %[w5].s[0] \n" - "fmla v16.4s, v8.4s, %[w5].s[0] \n" - "fmla v17.4s, v9.4s, %[w5].s[0] \n" - "fmla v18.4s, v10.4s, %[w5].s[0] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s, %[w5].s[1] \n" - "fmla v16.4s, v9.4s, %[w5].s[1] \n" - "fmla v17.4s, v10.4s, %[w5].s[1] \n" - "fmla v18.4s, v11.4s, %[w5].s[1] \n" - - "fmla v15.4s, v9.4s , %[w5].s[2] \n" - "fmla v16.4s, v10.4s, %[w5].s[2] \n" - "fmla v17.4s, v11.4s, %[w5].s[2] \n" - "fmla v18.4s, v12.4s, %[w5].s[2] \n" - - "fmla v15.4s, v10.4s, %[w5].s[3] \n" - "fmla v16.4s, v11.4s, %[w5].s[3] \n" - "fmla v17.4s, v12.4s, %[w5].s[3] \n" - "fmla v18.4s, v13.4s, %[w5].s[3] \n" - - "fmla v15.4s, v11.4s, %[w6].s[0] \n" - "fmla v16.4s, v12.4s, %[w6].s[0] \n" - "fmla v17.4s, v13.4s, %[w6].s[0] \n" - "fmla v18.4s, v14.4s, %[w6].s[0] \n" - - "fmax v15.4s, v15.4s, v31.4s \n" - "fmax v16.4s, v16.4s, v31.4s \n" - "fmax v17.4s, v17.4s, v31.4s \n" - "fmax v18.4s, v18.4s, v31.4s \n" - - "st1 {v15.4s}, [%[dout0]], #16 \n" - "st1 {v16.4s}, [%[dout1]], #16 \n" - "st1 {v17.4s}, [%[dout2]], #16 \n" - "st1 {v18.4s}, [%[dout3]], #16 \n" - - "subs %w[cnt], %w[cnt], #1 \n" - "add x0, x0, #1 \n" - "bne 1b \n" - - : [dout0] "+r"(dout_ptr0), - [dout1] "+r"(dout_ptr1), - [dout2] "+r"(dout_ptr2), - [dout3] "+r"(dout_ptr3), - [cnt] "+r"(mid_loop), - [dinl] "+r"(dinl_ptr) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(vbias) - : "cc", - "memory", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7", - "x8", - "x9", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v31"); - } - dinl[0] += 4 * mid_cnt; - dinl[1] += 4 * mid_cnt; - dinl[2] += 4 * mid_cnt; - dinl[3] += 4 * mid_cnt; - dinl[4] += 4 * mid_cnt; - dinl[5] += 4 * mid_cnt; - dinl[6] += 4 * mid_cnt; - dinl[7] += 4 * mid_cnt; - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - dinl[0]++; - dinl[1]++; - dinl[2]++; - dinl[3]++; - dinl[4]++; - dinl[5]++; - dinl[6]++; - dinl[7]++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - //! 
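v31 is zeroed once at entry ("movi v31.4s, #0"), so the four fmax instructions above apply ReLU in-register to each accumulator just before the stores. The equivalent intrinsic, for reference:

#include <arm_neon.h>

static inline float32x4_t relu_q(float32x4_t acc) {
  // Same operation as "fmax vN.4s, vN.4s, v31.4s" with v31 == 0.
  return vmaxq_f32(acc, vdupq_n_f32(0.f));
}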
deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din_list[0] = din_list[4]; - din_list[1] = din_list[5]; - din_list[2] = din_list[6]; - din_list[3] = din_list[7]; - din_list[4] = din_list[3] + w_in; - din_list[5] = din_list[4] + w_in; - din_list[6] = din_list[5] + w_in; - din_list[7] = din_list[6] + w_in; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_small_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 
4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - //! every h loop, deal with 8 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - const float* din6 = din5 + w_in_new; - const float* din7 = din6 + w_in_new; - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! 
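The small-width variant trades the in-loop padding logic of the large kernels for an explicitly pre-padded input buffer. Judging only from the call sites visible here (prepad_input(...) at entry, free(din_new) at exit), the helper behaves roughly like the sketch below; the real implementation lives elsewhere in this file and may differ:

#include <cstdlib>
#include <cstring>

static float* prepad_input_sketch(const float* din, int num, int ch,
                                  int h_in, int w_in, int pad) {
  int h_new = h_in + 2 * pad;
  int w_new = w_in + 2 * pad;
  size_t plane = static_cast<size_t>(h_new) * w_new;
  // Zeroed buffer; each channel plane is copied into its centre.
  float* out = static_cast<float*>(
      calloc(static_cast<size_t>(num) * ch * plane, sizeof(float)));
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch; ++c) {
      const float* src = din + (static_cast<size_t>(n) * ch + c) *
                                   (static_cast<size_t>(h_in) * w_in);
      float* dst = out + (static_cast<size_t>(n) * ch + c) * plane +
                   static_cast<size_t>(pad) * w_new + pad;
      for (int h = 0; h < h_in; ++h) {
        memcpy(dst + static_cast<size_t>(h) * w_new,
               src + static_cast<size_t>(h) * w_in, w_in * sizeof(float));
      }
    }
  }
  return out;  // caller frees, matching the free(din_new) above
}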
(h - pad_new) + 7 > h_in - 1 - if (h + 8 > h_in_new) { - switch (h + 8 - h_in_new) { - case 7: - din1 = zero_ptr; - case 6: - din2 = zero_ptr; - case 5: - din3 = zero_ptr; - case 4: - din4 = zero_ptr; - case 3: - din5 = zero_ptr; - case 2: - din6 = zero_ptr; - case 1: - din7 = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - const float* din_ptr7 = din7; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - *dout_ptr2++ = bias_c; - *dout_ptr3++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - din_ptr6, - din_ptr7, - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - din_ptr6++; - din_ptr7++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din7; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - din6 = din5 + w_in_new; - din7 = din6 + w_in_new; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! 
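At the bottom of each h-loop iteration above, four output rows have consumed eight input rows, so the pointers rotate: rows 4..7 become the next iteration's rows 0..3 and four fresh rows are chained on one stride apart (the deliberate switch fall-throughs earlier clamp out-of-range rows to zero_ptr, and surplus output rows to write_ptr). A compact sketch of that rotation:

static void rotate_input_rows(const float* rows[8], int row_stride) {
  // Mirrors "din4 = din3 + w_in_new" evaluated after "din3 = din7":
  // each new row is one stride past its (already updated) predecessor.
  for (int i = 0; i < 4; ++i) rows[i] = rows[i + 4];
  for (int i = 4; i < 8; ++i) rows[i] = rows[i - 1] + row_stride;
}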
deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } - free(din_new); -} - -void conv_depthwise_5x5s1_small_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - - //! every h loop, deal with 8 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - const float* din6 = din5 + w_in_new; - const float* din7 = din6 + w_in_new; - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! 
(h - pad_new) + 7 > h_in - 1 - if (h + 8 > h_in_new) { - switch (h + 8 - h_in_new) { - case 7: - din1 = zero_ptr; - case 6: - din2 = zero_ptr; - case 5: - din3 = zero_ptr; - case 4: - din4 = zero_ptr; - case 3: - din5 = zero_ptr; - case 2: - din6 = zero_ptr; - case 1: - din7 = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - const float* din_ptr7 = din7; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - *dout_ptr2++ = bias_relu; - *dout_ptr3++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - din_ptr6, - din_ptr7, - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - din_ptr6++; - din_ptr7++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din7; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - din6 = din5 + w_in_new; - din7 = din6 + w_in_new; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } - free(din_new); -} - -#else - -//! kernel for one out without extracting data mid -//! 
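The ARMv7 kernels that follow have only sixteen q-registers to work with, so instead of four-row blocks they compute two output rows at a time and fold the per-lane products with pairwise adds. The recurring vpadd reduction (three vpadd instructions per output pair), expressed with intrinsics:

#include <arm_neon.h>

static float32x2_t reduce_two_rows(float32x4_t acc0, float32x4_t acc1) {
  float32x2_t a = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
  float32x2_t b = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
  // lane 0 = sum(acc0) -> dout0, lane 1 = sum(acc1) -> dout1
  return vpadd_f32(a, b);
}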
deal with two lines out -void compute_one_out_without_extract(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d6[0]}, [%[din0]] \n" - "vld1.32 {d6[1]}, [%[din1]] \n" - "vld1.32 {d7[0]}, [%[din2]] \n" - "vld1.32 {d7[1]}, [%[din3]] \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d8[0]}, [%[din4]] \n" - "vld1.32 {d8[1]}, [%[din5]] \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights col4 - "sub %[wh], #64 \n" - "vld1.32 {d4[0]}, [%[wh]], r0 \n" - "vld1.32 {d4[1]}, [%[wh]], r0 \n" - "vld1.32 {d5[0]}, [%[wh]], r0 \n" - "vld1.32 {d5[1]}, [%[wh]], r0 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vext.32 q5, q3, q4, #1 \n" - - "vmla.f32 q9, q2, q3 \n" - "vmla.f32 q10, q2, q5 \n" - - "vld1.32 {d4[0]}, [%[wh]] \n" - "vld1.32 {d6}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - "vmla.f32 d18, d8, d4[0] \n" - - // add bias - "vadd.f32 d18, d18, d6 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -} - -//! kernel for one out without extracting data mid -//! deal with two lines out -void compute_one_out_without_extract_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vmov.i32 q15, #0x0 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! 
\n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d6[0]}, [%[din0]] \n" - "vld1.32 {d6[1]}, [%[din1]] \n" - "vld1.32 {d7[0]}, [%[din2]] \n" - "vld1.32 {d7[1]}, [%[din3]] \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d8[0]}, [%[din4]] \n" - "vld1.32 {d8[1]}, [%[din5]] \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights col4 - "sub %[wh], #64 \n" - "vld1.32 {d4[0]}, [%[wh]], r0 \n" - "vld1.32 {d4[1]}, [%[wh]], r0 \n" - "vld1.32 {d5[0]}, [%[wh]], r0 \n" - "vld1.32 {d5[1]}, [%[wh]], r0 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vext.32 q5, q3, q4, #1 \n" - - "vmla.f32 q9, q2, q3 \n" - "vmla.f32 q10, q2, q5 \n" - - "vld1.32 {d4[0]}, [%[wh]] \n" - "vld1.32 {d6}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - "vmla.f32 d18, d8, d4[0] \n" - - // add bias - "vadd.f32 d18, d18, d6 \n" - - // relu - "vmax.f32 d18, d18, d30 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -} - -//! kernel for one out without extracting data pre -//! deal with two lines out -void compute_one_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #4 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - // load bias - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -} - -//! kernel for one out without extracting data pre -//! 
deal with two lines out -void compute_one_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #4 \n" - "vmov.i32 q15, #0x0 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - // load bias - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - // relu - "vmax.f32 d18, d18, d30 \n" - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -} - -//! kernel for one out with extracting data post -//! deal with two lines out -void compute_one_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -} - -//! kernel for one out with extracting data post -//! 
deal with two lines out -void compute_one_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vmov.i32 q15, #0x0 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - // relu - "vmax.f32 d18, d18, d30 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -} - -//! kernel for two out with extracting data pre -//! deal with two lines out -void compute_two_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #0 \n" - "add %[wh], #8 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! 
\n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d22, d23 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for two out with extracting data pre -//! deal with two lines out -void compute_two_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #0 \n" - "add %[wh], #8 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! 
\n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d22, d23 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q9 \n" - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for two out with extracting data post -//! deal with two lines out -void compute_two_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - //! out zero - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - "vpadd.f32 d22, d22, d23 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - "vext.32 q8, q8, q15, #1 \n" - - //! 
out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for two out with extracting data post -//! deal with two lines out -void compute_two_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - //! out zero - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - "vpadd.f32 d22, d22, d23 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - "vext.32 q8, q8, q15, #1 \n" - - //! 
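The "vext.32 qN, qN, q15, #1" block above shifts each input vector left one lane, with q15 == 0 filling in from the right. That emulates the right-hand zero padding, so the very same multiply-accumulate sequence can then produce the next, partially padded output column. With intrinsics:

#include <arm_neon.h>

static inline float32x4_t shift_in_zero(float32x4_t v) {
  // {v1, v2, v3, 0}: window advances one column, zero enters on the right.
  return vextq_f32(v, vdupq_n_f32(0.f), 1);
}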
out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q9 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data pre -//! deal with two lines out -void compute_three_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #12 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out zero - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - "vpadd.f32 d22, d18, d20 \n" - - //! 
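The reductions leave each d-register grouped by output column, i.e. {row0, row1}; the "vtrn.32 d22, d23" step transposes that 2x2 lane matrix so each final vst1 writes two adjacent columns of a single output row. Equivalent intrinsic form:

#include <arm_neon.h>

static void transpose_2x2(float32x2_t* a, float32x2_t* b) {
  float32x2x2_t t = vtrn_f32(*a, *b);  // val[0] = {a0, b0}, val[1] = {a1, b1}
  *a = t.val[0];                       // row 0: columns j, j+1
  *b = t.val[1];                       // row 1: columns j, j+1
}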
out one - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d30 \n" - - // store result - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data pre -//! deal with two lines out -void compute_three_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #12 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out zero - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - "vpadd.f32 d22, d18, d20 \n" - - //! 
out one - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q8, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q8 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d30 \n" - - // relu - "vmax.f32 d18, d18, d16 \n" - - // store result - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_three_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! 
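// The _relu variants fuse the activation into the epilogue: a q register is
// zeroed with "vmov.i32 #0" and the biased result is clamped with
// "vmax.f32" just before the store. The same step with intrinsics, as an
// illustrative sketch:
#include <arm_neon.h>
static inline float32x4_t bias_relu(float32x4_t acc, float32x4_t vbias) {
  return vmaxq_f32(vaddq_f32(acc, vbias), vdupq_n_f32(0.f));
}
//!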
out zero && two - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - //! out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - "vadd.f32 d16, d16, d30 \n" - - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - "vst1.32 {d16[0]}, [%[dout0]]! \n" - "vst1.32 {d16[1]}, [%[dout1]]! \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_three_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! 
out zero && two - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - //! out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - "vadd.f32 d16, d16, d30 \n" - - // relu - "vmax.f32 q11, q11, q9 \n" - "vmax.f32 d16, d16, d18 \n" - - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - "vst1.32 {d16[0]}, [%[dout0]]! \n" - "vst1.32 {d16[1]}, [%[dout1]]! \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for four out with extracting data pre -//! deal with two lines out -void compute_four_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #16 \n" - - //! 
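// In the *_extract_post kernels, right-border columns are produced by
// shifting each input row one lane left and pulling zeros in from a zeroed
// register ("vmov.f32 q15, #0.0" + "vext.32 q2, q2, q15, #1"), which
// assumes zero padding beyond the right edge. A sketch of that shift:
#include <arm_neon.h>
static inline float32x4_t shift_in_zero(float32x4_t row) {
  return vextq_f32(row, vdupq_n_f32(0.f), 1);  // {r1, r2, r3, 0}
}
//!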
out zero - // load input - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - // load weights - "vld1.32 d0[0], [%[wh]], r0 \n" - "vld1.32 d0[1], [%[wh]], r0 \n" - "vld1.32 d1[0], [%[wh]], r0 \n" - "vld1.32 d1[1], [%[wh]], r0 \n" - "vld1.32 d2[0], [%[wh]]\n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmla.f32 d22, d6, d2[0] \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out one - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - - "vpadd.f32 d23, d18, d20 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d18, d19 \n" - - //! 
out three - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for four out with extracting data pre -//! deal with two lines out -void compute_four_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #16 \n" - - //! out zero - // load input - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - // load weights - "vld1.32 d0[0], [%[wh]], r0 \n" - "vld1.32 d0[1], [%[wh]], r0 \n" - "vld1.32 d1[0], [%[wh]], r0 \n" - "vld1.32 d1[1], [%[wh]], r0 \n" - "vld1.32 d2[0], [%[wh]]\n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmla.f32 d22, d6, d2[0] \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out one - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - - "vpadd.f32 d23, d18, d20 \n" - "vmov.i32 q8, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q8 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! 
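// "mov r0, #20" used throughout these kernels is a 20-byte post-increment:
// the 5x5 kernel is stored row-major as 25 floats, so stepping 5 floats per
// load walks down one column of the kernel. Equivalent scalar indexing, as
// a sketch:
static inline float kernel_at(const float* wh, int row, int col) {
  return wh[row * 5 + col];  // 20-byte row stride == 5 floats
}
//!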
out two - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d18, d19 \n" - - //! out three - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q8 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_four_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #12 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]], r1 \n" - "vld1.32 {d6-d7}, [%[din1]], r1 \n" - "vld1.32 {d8-d9}, [%[din2]], r1 \n" - "vld1.32 {d10-d11}, [%[din3]], r1 \n" - "vld1.32 {d12-d13}, [%[din4]], r1 \n" - "vld1.32 {d14-d15}, [%[din5]], r1 \n" - - //! 
out zero && two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - //! out one - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out three - "sub %[wh], #80 \n" - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - "vld1.32 {d0[0]}, [%[wh]], r0 \n" - "vld1.32 {d0[1]}, [%[wh]], r0 \n" - "vld1.32 {d1[0]}, [%[wh]], r0 \n" - "vld1.32 {d1[1]}, [%[wh]], r0 \n" - "vld1.32 {d2[0]}, [%[wh]] \n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d20, d20, d21 \n" - "vpadd.f32 d17, d18, d20 \n" - - "vmla.f32 d17, d6, d2[0] \n" - - // trn out neon register - "vtrn.32 d16, d17 \n" - - // add bias - "vadd.f32 q8, q8, q15 \n" - - // store result - "vst1.32 {d16}, [%[dout0]] \n" - "vst1.32 {d17}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! 
deal with two lines out -void compute_four_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #12 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]], r1 \n" - "vld1.32 {d6-d7}, [%[din1]], r1 \n" - "vld1.32 {d8-d9}, [%[din2]], r1 \n" - "vld1.32 {d10-d11}, [%[din3]], r1 \n" - "vld1.32 {d12-d13}, [%[din4]], r1 \n" - "vld1.32 {d14-d15}, [%[din5]], r1 \n" - - //! out zero && two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - //! out one - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q5, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q5 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! 
out three - "sub %[wh], #80 \n" - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - "vld1.32 {d0[0]}, [%[wh]], r0 \n" - "vld1.32 {d0[1]}, [%[wh]], r0 \n" - "vld1.32 {d1[0]}, [%[wh]], r0 \n" - "vld1.32 {d1[1]}, [%[wh]], r0 \n" - "vld1.32 {d2[0]}, [%[wh]] \n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d20, d20, d21 \n" - "vpadd.f32 d17, d18, d20 \n" - - "vmla.f32 d17, d6, d2[0] \n" - - // trn out neon register - "vtrn.32 d16, d17 \n" - - // add bias - "vadd.f32 q8, q8, q15 \n" - - // relu - "vmax.f32 q8, q8, q5 \n" - - // store result - "vst1.32 {d16}, [%[dout0]] \n" - "vst1.32 {d17}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -void conv_depthwise_5x5s1_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) +#ifdef __aarch64__ +void conv_depthwise_5x5s1_fp32(float* dout, + const float* din, const float* weights, const float* bias, - int pad, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + const operators::ConvParam& param, ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[6]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 6; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - //! 
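// The removed implementation splits the requested padding into pad_new (at
// most 4, covered by the extract kernels) and pad_0 (the remainder, which
// can only ever produce bias/zero outputs), then counts 4-wide vector steps
// for the interior of each row. The same arithmetic as a standalone sketch:
struct PadSplit {
  int pad_new, pad_0, mid_cnt, mid_remain;
};
static PadSplit split_pad_5x5(int pad, int w_out) {
  PadSplit s;
  s.pad_new = pad > 4 ? 4 : pad;         // at most kernel_size - 1
  s.pad_0 = pad - s.pad_new;             // emitted as constant rows/columns
  int mid = w_out - 2 * pad;             // interior outputs per row
  s.mid_cnt = mid >> 2;                  // full 4-wide iterations
  s.mid_remain = mid - (s.mid_cnt << 2); // leftover single outputs
  return s;
}
//!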
every h loop, deal with 6 line input - const float* din0 = din_list[0]; - const float* din1 = din_list[1]; - const float* din2 = din_list[2]; - const float* din3 = din_list[3]; - const float* din4 = din_list[4]; - const float* din5 = din_list[5]; - - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 6 - pad_new > h_in) { - switch (h + 6 - pad_new - h_in) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - //! mid loop - if (mid_cnt > 0) { - int mid_loop = mid_cnt; - const float* weights_ptr = weights_c; - asm volatile( - //! din: q7-q12 - //! dout: q13, q14 - "mov r1, #20 \n" - //! load weights - "vld1.32 {d0-d1}, [%[wh]], r1 \n" - "vld1.32 {d2-d3}, [%[wh]], r1 \n" - "vld1.32 {d4-d5}, [%[wh]], r1 \n" - "vld1.32 {d6-d7}, [%[wh]], r1 \n" - "vld1.32 {d8-d9}, [%[wh]] \n" - - "sub %[wh], #64 \n" - "vld1.32 {d10[0]}, [%[wh]], r1 \n" - "vld1.32 {d10[1]}, [%[wh]], r1 \n" - "vld1.32 {d11[0]}, [%[wh]], r1 \n" - "vld1.32 {d11[1]}, [%[wh]], r1 \n" - "vld1.32 {d12[0]}, [%[wh]] \n" - - //! load input - "mov r1, #4 \n" - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - //! load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - "1: \n" - //! 
add bias to output
- "vmov.32 q13, q15 \n"
- "vmov.32 q14, q15 \n"
-
- "pld [%[din0]] \n"
- "pld [%[din1]] \n"
- "pld [%[din2]] \n"
- "pld [%[din3]] \n"
- "pld [%[din4]] \n"
- "pld [%[din5]] \n"
-
- // weights col 0
- "vmla.f32 q13, q7, d0[0] \n"
- "vmla.f32 q14, q8, d0[0] \n"
-
- "vmla.f32 q13, q8, d2[0] \n"
- "vmla.f32 q14, q9, d2[0] \n"
-
- "vld1.32 {d14-d15}, [%[din0]], r1 \n"
- "vld1.32 {d16-d17}, [%[din1]], r1 \n"
-
- "vmla.f32 q13, q9, d4[0] \n"
- "vmla.f32 q14, q10, d4[0] \n"
-
- "vmla.f32 q13, q10, d6[0] \n"
- "vmla.f32 q14, q11, d6[0] \n"
-
- "vld1.32 {d18-d19}, [%[din2]], r1 \n"
- "vld1.32 {d20-d21}, [%[din3]], r1 \n"
-
- "vmla.f32 q13, q11, d8[0] \n"
- "vmla.f32 q14, q12, d8[0] \n"
-
- "vld1.32 {d22-d23}, [%[din4]], r1 \n"
- "vld1.32 {d24-d25}, [%[din5]], r1 \n"
-
- // weights col 1
- "vmla.f32 q13, q7, d0[1] \n"
- "vmla.f32 q14, q8, d0[1] \n"
-
- "vmla.f32 q13, q8, d2[1] \n"
- "vmla.f32 q14, q9, d2[1] \n"
-
- "vld1.32 {d14-d15}, [%[din0]], r1 \n"
- "vld1.32 {d16-d17}, [%[din1]], r1 \n"
-
- "vmla.f32 q13, q9, d4[1] \n"
- "vmla.f32 q14, q10, d4[1] \n"
-
- "vmla.f32 q13, q10, d6[1] \n"
- "vmla.f32 q14, q11, d6[1] \n"
-
- "vld1.32 {d18-d19}, [%[din2]], r1 \n"
- "vld1.32 {d20-d21}, [%[din3]], r1 \n"
-
- "vmla.f32 q13, q11, d8[1] \n"
- "vmla.f32 q14, q12, d8[1] \n"
-
- "vld1.32 {d22-d23}, [%[din4]], r1 \n"
- "vld1.32 {d24-d25}, [%[din5]], r1 \n"
-
- // weights col 2
- "vmla.f32 q13, q7, d1[0] \n"
- "vmla.f32 q14, q8, d1[0] \n"
-
- "vmla.f32 q13, q8, d3[0] \n"
- "vmla.f32 q14, q9, d3[0] \n"
-
- "vld1.32 {d14-d15}, [%[din0]], r1 \n"
- "vld1.32 {d16-d17}, [%[din1]], r1 \n"
-
- "vmla.f32 q13, q9, d5[0] \n"
- "vmla.f32 q14, q10, d5[0] \n"
-
- "vmla.f32 q13, q10, d7[0] \n"
- "vmla.f32 q14, q11, d7[0] \n"
-
- "vld1.32 {d18-d19}, [%[din2]], r1 \n"
- "vld1.32 {d20-d21}, [%[din3]], r1 \n"
-
- "vmla.f32 q13, q11, d9[0] \n"
- "vmla.f32 q14, q12, d9[0] \n"
-
- "vld1.32 {d22-d23}, [%[din4]], r1 \n"
- "vld1.32 {d24-d25}, [%[din5]], r1 \n"
-
- // weights col 3
- "vmla.f32 q13, q7, d1[1] \n"
- "vmla.f32 q14, q8, d1[1] \n"
-
- "vmla.f32 q13, q8, d3[1] \n"
- "vmla.f32 q14, q9, d3[1] \n"
-
- "vld1.32 {d14-d15}, [%[din0]], r1 \n"
- "vld1.32 {d16-d17}, [%[din1]], r1 \n"
-
- "vmla.f32 q13, q9, d5[1] \n"
- "vmla.f32 q14, q10, d5[1] \n"
-
- "vmla.f32 q13, q10, d7[1] \n"
- "vmla.f32 q14, q11, d7[1] \n"
-
- "vld1.32 {d18-d19}, [%[din2]], r1 \n"
- "vld1.32 {d20-d21}, [%[din3]], r1 \n"
-
- "vmla.f32 q13, q11, d9[1] \n"
- "vmla.f32 q14, q12, d9[1] \n"
-
- "vld1.32 {d22-d23}, [%[din4]], r1 \n"
- "vld1.32 {d24-d25}, [%[din5]], r1 \n"
-
- // weights col 4
- "vmla.f32 q13, q7, d10[0] \n"
- "vmla.f32 q14, q8, d10[0] \n"
-
- "vmla.f32 q13, q8, d10[1] \n"
- "vmla.f32 q14, q9, d10[1] \n"
-
- "vmla.f32 q13, q9, d11[0] \n"
- "vmla.f32 q14, q10, d11[0] \n"
-
- "vmla.f32 q13, q10, d11[1] \n"
- "vmla.f32 q14, q11, d11[1] \n"
-
- "vmla.f32 q13, q11, d12[0] \n"
- "vmla.f32 q14, q12, d12[0] \n"
-
- // store result
- "vst1.32 {d26-d27}, [%[out0]]! \n"
- "vst1.32 {d28-d29}, [%[out1]]!
\n" - - "subs %[cnt], #1 \n" - "bne 1b \n" - - "sub %[din0], r1 \n" - "sub %[din1], r1 \n" - "sub %[din2], r1 \n" - "sub %[din3], r1 \n" - "sub %[din4], r1 \n" - "sub %[din5], r1 \n" - - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3), - [din4] "+r"(din_ptr4), - [din5] "+r"(din_ptr5), - [out0] "+r"(dout_ptr0), - [out1] "+r"(dout_ptr1), - [wh] "+r"(weights_ptr), - [cnt] "+r"(mid_loop) - : [bias] "r"(vbias) - : "cc", - "memory", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - //! deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in; - din5 = din4 + w_in; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 
4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + auto act_param = param.activation_param; + const int hout_c_block = 4; + const int hout_r_kernel = 2; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 4; + + //! get h block + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(float) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(float) + //! win_round = wout_round + 4 + //! hin_r_block = hout_r_block + 4 + int hout_r_block = (llc_size - 16 * win_round * hout_c_block * threads) / + (win_round * hout_c_block * threads * 4 + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 4; + + float* tmp_work_space = ctx->workspace_data(); + float ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + float* tmp_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + + int out_row_stride = hout_c_block * wout_round; for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[6]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 6; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - //! 
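// The replacement kernel sizes its output-row block so that, across all
// threads, the pre-packed input (hout_r_block + 4 rows of win_round C4
// pixels) plus the pre-computed output tile fit in the llc_size budget.
// Solving the byte-count identity in the comment above for hout_r_block
// gives the expression used; the same arithmetic as a sketch:
static int row_block_from_cache(int llc_size, int threads, int win_round,
                                int wout_round, int c_block /* 4 */) {
  // llc_size >= threads*4*(win_round*c_block*(hr+4) + wout_round*c_block*hr)
  return (llc_size - 16 * win_round * c_block * threads) /
         (win_round * c_block * threads * 4 +
          c_block * wout_round * threads * 4);
}
//!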
every h loop, deal with 6 line input - const float* din0 = din_list[0]; - const float* din1 = din_list[1]; - const float* din2 = din_list[2]; - const float* din3 = din_list[3]; - const float* din4 = din_list[4]; - const float* din5 = din_list[5]; - - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 6 - pad_new > h_in) { - switch (h + 6 - pad_new - h_in) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; + const float* din_batch = din + n * chin * size_in_channel; + float* dout_batch = dout + n * chout * size_out_channel; + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel + 4; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size); + float* pre_out = pre_din + pre_in_size; +#else + float* pre_din = tmp_din; + float* pre_out = pre_din + pre_in_size; +#endif + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin, ptr_zero); + const float* block_inr0 = pre_din; + const float* block_inr1 = block_inr0 + in_len; + const float* block_inr2 = block_inr1 + in_len; + const float* block_inr3 = block_inr2 + in_len; + const float* block_inr4 = block_inr3 + in_len; + const float* block_inr5 = block_inr4 + in_len; + + const float* weight_c = weights + c * w_stride; + float bias_local[4] = {0, 0, 0, 0}; if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - - //! 
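// prepack_input_nxwc4_dw (assumed layout: 4 consecutive channels
// interleaved per pixel, rows of win_round pixels, zero-filled outside
// [hs,he) x [ws,we)) lets every 128-bit load in the kernel fetch one pixel
// across 4 channels. A sketch of the index mapping under that assumption:
static void pack_c4_row(const float* src, float* dst, int c0, int w0,
                        int width, int chin, int size_in_channel) {
  for (int w = 0; w < width; ++w) {
    for (int j = 0; j < 4; ++j) {  // channel-interleaved C4 layout
      int c = c0 + j;
      dst[w * 4 + j] = (c < chin) ? src[c * size_in_channel + w0 + w] : 0.f;
    }
  }
}
//!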
deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - //! mid loop - if (mid_cnt > 0) { - int mid_loop = mid_cnt; - const float* weights_ptr = weights_c; + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + const float* inr5 = block_inr5; + + float* ptr_out0 = pre_out + hk * out_row_stride; + float* ptr_out1 = ptr_out0 + out_row_stride; + // clang-format off + auto wptr = weight_c; asm volatile( - //! din: q7-q12 - //! dout: q13, q14 - "mov r1, #20 \n" - "vmov.i32 q15, #0x0 \n" - //! load weights - "vld1.32 {d0-d1}, [%[wh]], r1 \n" - "vld1.32 {d2-d3}, [%[wh]], r1 \n" - "vld1.32 {d4-d5}, [%[wh]], r1 \n" - "vld1.32 {d6-d7}, [%[wh]], r1 \n" - "vld1.32 {d8-d9}, [%[wh]] \n" - - "sub %[wh], #64 \n" - "vld1.32 {d10[0]}, [%[wh]], r1 \n" - "vld1.32 {d10[1]}, [%[wh]], r1 \n" - "vld1.32 {d11[0]}, [%[wh]], r1 \n" - "vld1.32 {d11[1]}, [%[wh]], r1 \n" - "vld1.32 {d12[0]}, [%[wh]] \n" - - //! load input - "mov r1, #4 \n" - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - "1: \n" - - //! 
load bias to output - "vld1.32 {d26-d27}, [%[bias]] \n" - "vld1.32 {d28-d29}, [%[bias]] \n" - - "pld [%[din0]] \n" - "pld [%[din1]] \n" - "pld [%[din2]] \n" - "pld [%[din3]] \n" - "pld [%[din4]] \n" - "pld [%[din5]] \n" - - // weights col 0 - "vmla.f32 q13, q7, d0[0] \n" - "vmla.f32 q14, q8, d0[0] \n" - - "vmla.f32 q13, q8, d2[0] \n" - "vmla.f32 q14, q9, d2[0] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d4[0] \n" - "vmla.f32 q14, q10, d4[0] \n" - - "vmla.f32 q13, q10, d6[0] \n" - "vmla.f32 q14, q11, d6[0] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d8[0] \n" - "vmla.f32 q14, q12, d8[0] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 1 - "vmla.f32 q13, q7, d0[1] \n" - "vmla.f32 q14, q8, d0[1] \n" - - "vmla.f32 q13, q8, d2[1] \n" - "vmla.f32 q14, q9, d2[1] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d4[1] \n" - "vmla.f32 q14, q10, d4[1] \n" - - "vmla.f32 q13, q10, d6[1] \n" - "vmla.f32 q14, q11, d6[1] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d8[1] \n" - "vmla.f32 q14, q12, d8[1] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 2 - "vmla.f32 q13, q7, d1[0] \n" - "vmla.f32 q14, q8, d1[0] \n" - - "vmla.f32 q13, q8, d3[0] \n" - "vmla.f32 q14, q9, d3[0] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d5[0] \n" - "vmla.f32 q14, q10, d5[0] \n" - - "vmla.f32 q13, q10, d7[0] \n" - "vmla.f32 q14, q11, d7[0] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d9[0] \n" - "vmla.f32 q14, q12, d9[0] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 3 - "vmla.f32 q13, q7, d1[1] \n" - "vmla.f32 q14, q8, d1[1] \n" - - "vmla.f32 q13, q8, d3[1] \n" - "vmla.f32 q14, q9, d3[1] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d5[1] \n" - "vmla.f32 q14, q10, d5[1] \n" - - "vmla.f32 q13, q10, d7[1] \n" - "vmla.f32 q14, q11, d7[1] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d9[1] \n" - "vmla.f32 q14, q12, d9[1] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 4 - "vmla.f32 q13, q7, d10[0] \n" - "vmla.f32 q14, q8, d10[0] \n" - - "vmla.f32 q13, q8, d10[1] \n" - "vmla.f32 q14, q9, d10[1] \n" - - "vmla.f32 q13, q9, d11[0] \n" - "vmla.f32 q14, q10, d11[0] \n" - - "vmla.f32 q13, q10, d11[1] \n" - "vmla.f32 q14, q11, d11[1] \n" - - "vmla.f32 q13, q11, d12[0] \n" - "vmla.f32 q14, q12, d12[0] \n" - - // relu - "vmax.f32 q13, q13, q15 \n" - "vmax.f32 q14, q14, q15 \n" - - // store result - "vst1.32 {d26-d27}, [%[out0]]! \n" - "vst1.32 {d28-d29}, [%[out1]]! 
\n" - - "subs %[cnt], #1 \n" - "bne 1b \n" - - "sub %[din0], r1 \n" - "sub %[din1], r1 \n" - "sub %[din2], r1 \n" - "sub %[din3], r1 \n" - "sub %[din4], r1 \n" - "sub %[din5], r1 \n" - - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3), - [din4] "+r"(din_ptr4), - [din5] "+r"(din_ptr5), - [out0] "+r"(dout_ptr0), - [out1] "+r"(dout_ptr1), - [wh] "+r"(weights_ptr), - [cnt] "+r"(mid_loop) - : [bias] "r"(vbias) - : "cc", - "memory", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - //! deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in; - din5 = din4 + w_in; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_small_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 
4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - //! every h loop, deal with 6 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 6 > h_in - 1 - if (h + 6 > h_in_new) { - switch (h + 6 - h_in_new) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - if (flag_bias) { - //! 
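// The *_small_impl variants avoid the edge kernels entirely: prepad_input
// copies each feature map into a buffer zero-padded by pad_new on every
// side, so only the generic per-pixel kernel runs. A sketch of such a
// pre-pad copy (hypothetical helper; the real prepad_input also handles
// the batch and channel dimensions):
#include <cstdlib>
#include <cstring>
static float* prepad_2d(const float* src, int h, int w, int pad) {
  int hp = h + 2 * pad, wp = w + 2 * pad;
  float* dst = static_cast<float*>(calloc(hp * wp, sizeof(float)));
  for (int r = 0; r < h; ++r) {
    memcpy(dst + (r + pad) * wp + pad, src + r * w, w * sizeof(float));
  }
  return dst;  // caller frees, as the removed code does with free(din_new)
}
//!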
deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } - free(din_new); -} - -void conv_depthwise_5x5s1_small_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - //! every h loop, deal with 6 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! 
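// Each trip of the h loop consumes six input rows and emits two output
// rows, then slides the window down by two: rows 2..5 become rows 0..3 and
// two fresh rows are appended (note din4 is derived from the just-rotated
// din3, exactly as in the removed code). A sketch of the rotation:
static void slide_window(const float* r[6], int w_in) {
  r[0] = r[2]; r[1] = r[3]; r[2] = r[4]; r[3] = r[5];
  r[4] = r[3] + w_in;  // first new input row
  r[5] = r[4] + w_in;  // second new input row
}
//!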
(h - pad_new) + 6 > h_in - 1 - if (h + 6 > h_in_new) { - switch (h + 6 - h_in_new) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! 
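The `switch (h + 6 - h_in_new)` above relies on deliberate fall-through: once the window runs past the bottom of the (pre-padded) input, every out-of-range row pointer from that case downward is redirected to the shared zero row. An equivalent loop, as a sketch:

    // Sketch of the fall-through switch: rows past the input bottom read zeros.
    // Row 0 is always in range, so only rows 1..5 are ever clamped.
    inline void ClampRowsToZero(const float* rows[6], int h, int h_in_new,
                                const float* zero_row) {
      if (h + 6 > h_in_new) {
        for (int k = h_in_new - h; k < 6; ++k) rows[k] = zero_row;
      }
    }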
deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); + "ldr q24, [%[bias]] \n" /* load bias to out00 */ + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[wc]], #64 \n" /* load w0-w3 */ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[inr0]], #64 \n" /* load inr0, 0-3 */ + "1:\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[inr1]], #64 \n" /* load inr1, 0-3 */ + "mov v25.16b, v24.16b \n" /* mov bias to out01 */ + "mov v26.16b, v24.16b \n" /* mov bias to out02 */ + "mov v27.16b, v24.16b \n" /* mov bias to out03 */ + "mov v28.16b, v24.16b \n" /* mov bias to out10 */ + "mov v29.16b, v24.16b \n" /* mov bias to out11 */ + "mov v30.16b, v24.16b \n" /* mov bias to out12 */ + "mov v31.16b, v24.16b \n" /* mov bias to out13 */ + // out row0 + "fmla v24.4s, v8.4s, v0.4s \n" /* out00 = w0 * inr00 */ + "fmla v25.4s, v9.4s, v0.4s \n" /* out01 = w0 * inr01 */ + "ldp q12, q13, [%[inr0]] \n" /* load inr0, 4-5 */ + "fmla v26.4s, v10.4s, v0.4s \n" /* out02 = w0 * inr02 */ + "fmla v27.4s, v11.4s, v0.4s \n" /* out03 = w0 * inr03 */ + "fmla v28.4s, v16.4s, v0.4s \n" /* out10 = w0 * inr10 */ + "fmla v29.4s, v17.4s, v0.4s \n" /* out11 = w0 * inr11 */ + "ldp q20, q21, [%[inr1]] \n" /* load inr1, 4-5 */ + "fmla v30.4s, v18.4s, v0.4s \n" /* out12 = w0 * inr12 */ + "fmla v31.4s, v19.4s, v0.4s \n" /* out13 = w0 * inr13 */ + "fmla v24.4s, v9.4s, v1.4s \n" /* out00 = w1 * inr01 */ + "fmla v25.4s, v10.4s, v1.4s \n" /* out01 = w1 * inr02 */ + "fmla v26.4s, v11.4s, v1.4s \n" /* out02 = w1 * inr03 */ + "fmla v27.4s, v12.4s, v1.4s \n" /* out03 = w1 * inr04 */ + "ldp q14, q15, [%[inr0], #32] \n" /* load inr0, 6-7 */ + "fmla v28.4s, v17.4s, v1.4s \n" /* out10 = w1 * inr11 */ + "fmla v29.4s, v18.4s, v1.4s \n" /* out11 = w1 * inr12 */ + "fmla v30.4s, v19.4s, v1.4s \n" /* out12 = w1 * inr13 */ + "fmla v31.4s, v20.4s, v1.4s \n" /* out13 = w1 * inr14 */ + "fmla v24.4s, v10.4s, v2.4s \n" /* out00 = w2 * inr02 */ + "fmla v25.4s, v11.4s, v2.4s \n" /* out01 = w2 * inr03 */ + "fmla v26.4s, v12.4s, v2.4s \n" /* out02 = w2 * inr04 */ + "fmla v27.4s, v13.4s, v2.4s \n" /* out03 = w2 * inr05 */ + "ldp q22, q23, [%[inr1], #32] \n" /* load inr1, 6-7 */ + "fmla v28.4s, v18.4s, v2.4s \n" /* out10 = w2 * inr12 */ + "fmla v29.4s, v19.4s, v2.4s \n" /* out11 = w2 * inr13 */ + "fmla v30.4s, v20.4s, v2.4s \n" /* out12 = w2 * inr14 */ + "fmla v31.4s, v21.4s, v2.4s \n" /* out13 = w2 * inr15 */ + "ldp q4, q5, [%[wc]], #32 \n" /* load w4-w5 */ + "fmla v24.4s, v11.4s, v3.4s \n" /* out00 = w3 * inr03 */ + "fmla v25.4s, v12.4s, v3.4s \n" /* out01 = w3 * inr04 */ + "fmla v26.4s, v13.4s, v3.4s \n" /* out02 = w3 * inr05 */ + "fmla v27.4s, v14.4s, v3.4s \n" /* out03 = w3 * inr06 */ + "ldp q6, q7, [%[wc]], #32 \n" /* load w6-w7 */ + "fmla v28.4s, v19.4s, v3.4s \n" /* out10 = w3 * inr13 */ + "fmla v29.4s, v20.4s, v3.4s \n" /* out11 = w3 * inr14 */ + "fmla v30.4s, v21.4s, v3.4s \n" /* out12 = w3 * inr15 */ + "fmla v31.4s, v22.4s, v3.4s \n" /* out13 = w3 * inr16 */ + "fmla v24.4s, v12.4s, v4.4s \n" /* out00 = w4 * inr04 */ + "fmla v25.4s, v13.4s, v4.4s \n" /* out01 = w4 * inr05 */ + "fmla v26.4s, v14.4s, v4.4s \n" /* out02 = w4 * inr06 */ + "fmla v27.4s, v15.4s, v4.4s \n" /* out03 = w4 * inr07 */ + "ldp q8, q9, [%[inr2]], #32 \n" /* load inr2, 0-1 */ + "fmla v28.4s, v20.4s, v4.4s \n" /* out10 = w4 * inr14 */ + "fmla v29.4s, v21.4s, v4.4s \n" /* out11 = w4 * inr15 */ + "fmla v30.4s, v22.4s, v4.4s \n" /* out12 = w4 * inr16 */ + "fmla v31.4s, v23.4s, v4.4s \n" /* out13 = w4 * inr17 */ + "ldp q10, q11, [%[inr2]], #32\n" 
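The new AArch64 body that begins here keeps a full 2x4 tile of 4-channel outputs in v24..v31: the bias vector is loaded once into v24 and copied into the other seven accumulators, then each weight vector w0..w24 is multiply-accumulated against four neighboring packed input vectors per output row. A minimal intrinsics sketch of that pattern (names are illustrative; `vfmaq_f32` is the AArch64 form, the armv7 path uses `vmlaq_f32`):

    #include <arm_neon.h>

    // Sketch: one weight vector FMA'd across a 1x4 strip of packed (c/4)
    // inputs, mirroring the "fmla v24..v27, vIN, w" groups in the asm body.
    inline void FmaStripC4(const float* in /* >= 4 packed vectors */,
                           float32x4_t w, float32x4_t out[4]) {
      for (int i = 0; i < 4; ++i)
        out[i] = vfmaq_f32(out[i], vld1q_f32(in + 4 * i), w);
    }

    // Sketch: bias broadcast into all accumulators ("mov vN.16b, v24.16b").
    inline void InitTileWithBias(const float* bias_local, float32x4_t out[8]) {
      float32x4_t vb = vld1q_f32(bias_local);
      for (int i = 0; i < 8; ++i) out[i] = vb;
    }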
/* load inr2, 2-3 */ + // out row1 + "fmla v24.4s, v16.4s, v5.4s \n" /* out00 = w5 * inr10 */ + "fmla v25.4s, v17.4s, v5.4s \n" /* out01 = w5 * inr11 */ + "fmla v26.4s, v18.4s, v5.4s \n" /* out02 = w5 * inr12 */ + "fmla v27.4s, v19.4s, v5.4s \n" /* out03 = w5 * inr13 */ + "ldp q12, q13, [%[inr2]] \n" /* load inr2, 4-5 */ + "fmla v28.4s, v8.4s, v5.4s \n" /* out10 = w5 * inr20 */ + "fmla v29.4s, v9.4s, v5.4s \n" /* out11 = w5 * inr21 */ + "fmla v30.4s, v10.4s, v5.4s \n" /* out12 = w5 * inr22 */ + "fmla v31.4s, v11.4s, v5.4s \n" /* out13 = w5 * inr23 */ + "fmla v24.4s, v17.4s, v6.4s \n" /* out00 = w6 * inr11 */ + "fmla v25.4s, v18.4s, v6.4s \n" /* out01 = w6 * inr12 */ + "fmla v26.4s, v19.4s, v6.4s \n" /* out02 = w6 * inr13 */ + "fmla v27.4s, v20.4s, v6.4s \n" /* out03 = w6 * inr14 */ + "ldp q14, q15, [%[inr2], #32]\n" /* load inr2, 6-7 */ + "fmla v28.4s, v9.4s, v6.4s \n" /* out10 = w6 * inr21 */ + "fmla v29.4s, v10.4s, v6.4s \n" /* out11 = w6 * inr22 */ + "fmla v30.4s, v11.4s, v6.4s \n" /* out12 = w6 * inr23 */ + "fmla v31.4s, v12.4s, v6.4s \n" /* out13 = w6 * inr24 */ + "fmla v24.4s, v18.4s, v7.4s \n" /* out00 = w7 * inr12 */ + "fmla v25.4s, v19.4s, v7.4s \n" /* out01 = w7 * inr13 */ + "fmla v26.4s, v20.4s, v7.4s \n" /* out02 = w7 * inr14 */ + "fmla v27.4s, v21.4s, v7.4s \n" /* out03 = w7 * inr15 */ + "ldp q0, q1, [%[wc]], #32 \n" /* load w8-w9 */ + "fmla v28.4s, v10.4s, v7.4s \n" /* out10 = w7 * inr22 */ + "fmla v29.4s, v11.4s, v7.4s \n" /* out11 = w7 * inr23 */ + "fmla v30.4s, v12.4s, v7.4s \n" /* out12 = w7 * inr24 */ + "fmla v31.4s, v13.4s, v7.4s \n" /* out13 = w7 * inr25 */ + "fmla v24.4s, v19.4s, v0.4s \n" /* out00 = w8 * inr13 */ + "fmla v25.4s, v20.4s, v0.4s \n" /* out01 = w8 * inr14 */ + "fmla v26.4s, v21.4s, v0.4s \n" /* out02 = w8 * inr15 */ + "fmla v27.4s, v22.4s, v0.4s \n" /* out03 = w8 * inr16 */ + "ldp q2, q3, [%[wc]], #32 \n" /* load w10-w11 */ + "fmla v28.4s, v11.4s, v0.4s \n" /* out10 = w8 * inr23 */ + "fmla v29.4s, v12.4s, v0.4s \n" /* out11 = w8 * inr24 */ + "fmla v30.4s, v13.4s, v0.4s \n" /* out12 = w8 * inr25 */ + "fmla v31.4s, v14.4s, v0.4s \n" /* out13 = w8 * inr26 */ + "ldp q16, q17, [%[inr3]], #32\n" /* load inr3, 0-1 */ + "fmla v24.4s, v20.4s, v1.4s \n" /* out00 = w9 * inr14 */ + "fmla v25.4s, v21.4s, v1.4s \n" /* out01 = w9 * inr15 */ + "fmla v26.4s, v22.4s, v1.4s \n" /* out02 = w9 * inr16 */ + "fmla v27.4s, v23.4s, v1.4s \n" /* out03 = w9 * inr17 */ + "ldp q18, q19, [%[inr3]], #32\n" /* load inr3, 2-3 */ + "fmla v28.4s, v12.4s, v1.4s \n" /* out10 = w9 * inr24 */ + "fmla v29.4s, v13.4s, v1.4s \n" /* out11 = w9 * inr25 */ + "fmla v30.4s, v14.4s, v1.4s \n" /* out12 = w9 * inr26 */ + "fmla v31.4s, v15.4s, v1.4s \n" /* out13 = w9 * inr27 */ + // out row2 + "fmla v24.4s, v8.4s, v2.4s \n" /* out00 = w10 * inr20 */ + "fmla v25.4s, v9.4s, v2.4s \n" /* out01 = w10 * inr21 */ + "fmla v26.4s, v10.4s, v2.4s \n" /* out02 = w10 * inr22 */ + "fmla v27.4s, v11.4s, v2.4s \n" /* out03 = w10 * inr23 */ + "ldp q4, q5, [%[wc]], #32 \n" /* load w12-w13 */ + "fmla v28.4s, v16.4s, v2.4s \n" /* out10 = w10 * inr30 */ + "fmla v29.4s, v17.4s, v2.4s \n" /* out11 = w10 * inr31 */ + "fmla v30.4s, v18.4s, v2.4s \n" /* out12 = w10 * inr32 */ + "fmla v31.4s, v19.4s, v2.4s \n" /* out13 = w10 * inr33 */ + "ldp q20, q21, [%[inr3]] \n" /* load inr3, 4-5 */ + "fmla v24.4s, v9.4s, v3.4s \n" /* out00 = w11 * inr21 */ + "fmla v25.4s, v10.4s, v3.4s \n" /* out01 = w11 * inr22 */ + "fmla v26.4s, v11.4s, v3.4s \n" /* out02 = w11 * inr23 */ + "fmla v27.4s, v12.4s, v3.4s \n" /* out03 = w11 * inr24 */ + "ldp 
q22, q23, [%[inr3], #32]\n" /* load inr3, 6-7 */ + "fmla v28.4s, v17.4s, v3.4s \n" /* out10 = w11 * inr31 */ + "fmla v29.4s, v18.4s, v3.4s \n" /* out11 = w11 * inr32 */ + "fmla v30.4s, v19.4s, v3.4s \n" /* out12 = w11 * inr33 */ + "fmla v31.4s, v20.4s, v3.4s \n" /* out13 = w11 * inr34 */ + "fmla v24.4s, v10.4s, v4.4s \n" /* out00 = w12 * inr22 */ + "fmla v25.4s, v11.4s, v4.4s \n" /* out01 = w12 * inr23 */ + "fmla v26.4s, v12.4s, v4.4s \n" /* out02 = w12 * inr24 */ + "fmla v27.4s, v13.4s, v4.4s \n" /* out03 = w12 * inr25 */ + "ldp q6, q7, [%[wc]], #32 \n" /* load w14-w15 */ + "fmla v28.4s, v18.4s, v4.4s \n" /* out10 = w12 * inr32 */ + "fmla v29.4s, v19.4s, v4.4s \n" /* out11 = w12 * inr33 */ + "fmla v30.4s, v20.4s, v4.4s \n" /* out12 = w12 * inr34 */ + "fmla v31.4s, v21.4s, v4.4s \n" /* out13 = w12 * inr35 */ + "fmla v24.4s, v11.4s, v5.4s \n" /* out00 = w13 * inr23 */ + "fmla v25.4s, v12.4s, v5.4s \n" /* out01 = w13 * inr24 */ + "fmla v26.4s, v13.4s, v5.4s \n" /* out02 = w13 * inr25 */ + "fmla v27.4s, v14.4s, v5.4s \n" /* out03 = w13 * inr26 */ + "ldp q8, q9, [%[inr4]], #32 \n" /* load inr4, 0-1 */ + "fmla v28.4s, v19.4s, v5.4s \n" /* out10 = w13 * inr33 */ + "fmla v29.4s, v20.4s, v5.4s \n" /* out11 = w13 * inr34 */ + "fmla v30.4s, v21.4s, v5.4s \n" /* out12 = w13 * inr35 */ + "fmla v31.4s, v22.4s, v5.4s \n" /* out13 = w13 * inr36 */ + "fmla v24.4s, v12.4s, v6.4s \n" /* out00 = w14 * inr24 */ + "fmla v25.4s, v13.4s, v6.4s \n" /* out01 = w14 * inr25 */ + "fmla v26.4s, v14.4s, v6.4s \n" /* out02 = w14 * inr26 */ + "fmla v27.4s, v15.4s, v6.4s \n" /* out03 = w14 * inr27 */ + "ldp q10, q11, [%[inr4]], #32\n" /* load inr4, 2-3 */ + "fmla v28.4s, v20.4s, v6.4s \n" /* out10 = w14 * inr34 */ + "fmla v29.4s, v21.4s, v6.4s \n" /* out11 = w14 * inr35 */ + "fmla v30.4s, v22.4s, v6.4s \n" /* out12 = w14 * inr36 */ + "fmla v31.4s, v23.4s, v6.4s \n" /* out13 = w14 * inr37 */ + "ldp q0, q1, [%[wc]], #32 \n" /* load w16-w17 */ + // out row3 + "fmla v24.4s, v16.4s, v7.4s \n" /* out00 = w15 * inr30 */ + "fmla v25.4s, v17.4s, v7.4s \n" /* out01 = w15 * inr31 */ + "fmla v26.4s, v18.4s, v7.4s \n" /* out02 = w15 * inr32 */ + "fmla v27.4s, v19.4s, v7.4s \n" /* out03 = w15 * inr33 */ + "ldp q12, q13, [%[inr4]] \n" /* load inr4, 4-5 */ + "fmla v28.4s, v8.4s, v7.4s \n" /* out10 = w15 * inr40 */ + "fmla v29.4s, v9.4s, v7.4s \n" /* out11 = w15 * inr41 */ + "fmla v30.4s, v10.4s, v7.4s \n" /* out12 = w15 * inr42 */ + "fmla v31.4s, v11.4s, v7.4s \n" /* out13 = w15 * inr42 */ + "ldp q2, q3, [%[wc]], #32 \n" /* load w18-w19 */ + "fmla v24.4s, v17.4s, v0.4s \n" /* out00 = w16 * inr31 */ + "fmla v25.4s, v18.4s, v0.4s \n" /* out01 = w16 * inr32 */ + "fmla v26.4s, v19.4s, v0.4s \n" /* out02 = w16 * inr33 */ + "fmla v27.4s, v20.4s, v0.4s \n" /* out03 = w16 * inr34 */ + "ldp q14, q15, [%[inr4], #32]\n" /* load inr4, 6-7 */ + "fmla v28.4s, v9.4s, v0.4s \n" /* out10 = w16 * inr41 */ + "fmla v29.4s, v10.4s, v0.4s \n" /* out11 = w16 * inr42 */ + "fmla v30.4s, v11.4s, v0.4s \n" /* out12 = w16 * inr43 */ + "fmla v31.4s, v12.4s, v0.4s \n" /* out13 = w16 * inr44 */ + "fmla v24.4s, v18.4s, v1.4s \n" /* out00 = w17 * inr32 */ + "fmla v25.4s, v19.4s, v1.4s \n" /* out01 = w17 * inr33 */ + "fmla v26.4s, v20.4s, v1.4s \n" /* out02 = w17 * inr34 */ + "fmla v27.4s, v21.4s, v1.4s \n" /* out03 = w17 * inr35 */ + "ldp q4, q5, [%[wc]], #32 \n" /* load w20-w21 */ + "fmla v28.4s, v10.4s, v1.4s \n" /* out10 = w17 * inr42 */ + "fmla v29.4s, v11.4s, v1.4s \n" /* out11 = w17 * inr43 */ + "fmla v30.4s, v12.4s, v1.4s \n" /* out12 = w17 * inr44 */ + "fmla 
v31.4s, v13.4s, v1.4s \n" /* out13 = w17 * inr45 */ + "fmla v24.4s, v19.4s, v2.4s \n" /* out00 = w18 * inr33 */ + "fmla v25.4s, v20.4s, v2.4s \n" /* out01 = w18 * inr34 */ + "fmla v26.4s, v21.4s, v2.4s \n" /* out02 = w18 * inr35 */ + "fmla v27.4s, v22.4s, v2.4s \n" /* out03 = w18 * inr36 */ + "ldp q16, q17, [%[inr5]], #32\n" /* load inr5, 0-1 */ + "fmla v28.4s, v11.4s, v2.4s \n" /* out10 = w18 * inr43 */ + "fmla v29.4s, v12.4s, v2.4s \n" /* out11 = w18 * inr44 */ + "fmla v30.4s, v13.4s, v2.4s \n" /* out12 = w18 * inr45 */ + "fmla v31.4s, v14.4s, v2.4s \n" /* out13 = w18 * inr46 */ + "fmla v24.4s, v20.4s, v3.4s \n" /* out00 = w19 * inr34 */ + "fmla v25.4s, v21.4s, v3.4s \n" /* out01 = w19 * inr35 */ + "fmla v26.4s, v22.4s, v3.4s \n" /* out02 = w19 * inr36 */ + "fmla v27.4s, v23.4s, v3.4s \n" /* out03 = w19 * inr37 */ + "ldp q18, q19, [%[inr5]], #32\n" /* load inr5, 2-3 */ + "fmla v28.4s, v12.4s, v3.4s \n" /* out10 = w19 * inr44 */ + "fmla v29.4s, v13.4s, v3.4s \n" /* out11 = w19 * inr45 */ + "fmla v30.4s, v14.4s, v3.4s \n" /* out12 = w19 * inr46 */ + "fmla v31.4s, v15.4s, v3.4s \n" /* out13 = w19 * inr47 */ + // out row4 + "fmla v24.4s, v8.4s, v4.4s \n" /* out00 = w20 * inr40 */ + "fmla v25.4s, v9.4s, v4.4s \n" /* out01 = w20 * inr41 */ + "fmla v26.4s, v10.4s, v4.4s \n" /* out02 = w20 * inr42 */ + "fmla v27.4s, v11.4s, v4.4s \n" /* out03 = w20 * inr43 */ + "ldp q20, q21, [%[inr5]] \n" /* load inr5, 4-5 */ + "fmla v28.4s, v16.4s, v4.4s \n" /* out10 = w20 * inr50 */ + "fmla v29.4s, v17.4s, v4.4s \n" /* out11 = w20 * inr51 */ + "fmla v30.4s, v18.4s, v4.4s \n" /* out12 = w20 * inr52 */ + "fmla v31.4s, v19.4s, v4.4s \n" /* out13 = w20 * inr53 */ + "ldp q6, q7, [%[wc]], #32 \n" /* load w22-w23 */ + "fmla v24.4s, v9.4s, v5.4s \n" /* out00 = w21 * inr41 */ + "fmla v25.4s, v10.4s, v5.4s \n" /* out01 = w21 * inr42 */ + "fmla v26.4s, v11.4s, v5.4s \n" /* out02 = w21 * inr43 */ + "fmla v27.4s, v12.4s, v5.4s \n" /* out03 = w21 * inr44 */ + "ldp q22, q23, [%[inr5], #32]\n" /* load inr5, 6-7 */ + "fmla v28.4s, v17.4s, v5.4s \n" /* out10 = w21 * inr51 */ + "fmla v29.4s, v18.4s, v5.4s \n" /* out11 = w21 * inr52 */ + "fmla v30.4s, v19.4s, v5.4s \n" /* out12 = w21 * inr53 */ + "fmla v31.4s, v20.4s, v5.4s \n" /* out13 = w21 * inr54 */ + "ldp q8, q9, [%[inr0]], #32 \n" /* load inr0, 0-1 */ + "fmla v24.4s, v10.4s, v6.4s \n" /* out00 = w22 * inr42 */ + "fmla v25.4s, v11.4s, v6.4s \n" /* out01 = w22 * inr43 */ + "fmla v26.4s, v12.4s, v6.4s \n" /* out02 = w22 * inr44 */ + "fmla v27.4s, v13.4s, v6.4s \n" /* out03 = w22 * inr45 */ + "ldp q4, q5, [%[wc]], #-384 \n" /* load w24 */ + "fmla v28.4s, v18.4s, v6.4s \n" /* out10 = w22 * inr52 */ + "fmla v29.4s, v19.4s, v6.4s \n" /* out11 = w22 * inr53 */ + "fmla v30.4s, v20.4s, v6.4s \n" /* out12 = w22 * inr54 */ + "fmla v31.4s, v21.4s, v6.4s \n" /* out13 = w22 * inr55 */ + "ldp q0, q1, [%[wc]], #32 \n" /* load w0-w1 */ + "fmla v24.4s, v11.4s, v7.4s \n" /* out00 = w23 * inr43 */ + "fmla v25.4s, v12.4s, v7.4s \n" /* out01 = w23 * inr44 */ + "fmla v26.4s, v13.4s, v7.4s \n" /* out02 = w23 * inr45 */ + "fmla v27.4s, v14.4s, v7.4s \n" /* out03 = w23 * inr46 */ + "ldp q2, q3, [%[wc]], #32 \n" /* load w1-w2 */ + "fmla v28.4s, v19.4s, v7.4s \n" /* out10 = w23 * inr53 */ + "fmla v29.4s, v20.4s, v7.4s \n" /* out11 = w23 * inr54 */ + "fmla v30.4s, v21.4s, v7.4s \n" /* out12 = w23 * inr55 */ + "fmla v31.4s, v22.4s, v7.4s \n" /* out13 = w23 * inr56 */ + "ldp q10, q11, [%[inr0]], #32\n" /* load inr0, 2-3 */ + "fmla v24.4s, v12.4s, v4.4s \n" /* out00 = w24 * inr44 */ + "fmla v25.4s, 
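The `ldp q4, q5, [%[wc]], #-384` above is the weight-pointer reset for the next loop trip: the 25 per-channel weight vectors occupy 25 * 16 = 400 bytes, and when w24 is read the pointer sits 24 vectors (384 bytes) past w0, so the negative post-index rewinds it to w0; the two 32-byte loads that follow then re-establish the pre-loop state (w0..w3 in registers, pointer at w4). The arithmetic, as a sketch:

    // Sketch of the rewind arithmetic used by "ldp ..., #-384".
    constexpr int kVecBytes = 16;    // one q register
    constexpr int kWeightVecs = 25;  // 5x5 kernel, c/4 packed
    constexpr int kRewindBytes = (kWeightVecs - 1) * kVecBytes;
    static_assert(kRewindBytes == 384, "matches the post-index in the asm");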
v13.4s, v4.4s \n" /* out01 = w24 * inr45 */ + "fmla v26.4s, v14.4s, v4.4s \n" /* out02 = w24 * inr46 */ + "fmla v27.4s, v15.4s, v4.4s \n" /* out03 = w24 * inr47 */ + "stp q24, q25, [%[out0]], #32\n" /* store outr0, 0-1 */ + "fmla v28.4s, v20.4s, v4.4s \n" /* out10 = w24 * inr54 */ + "fmla v29.4s, v21.4s, v4.4s \n" /* out11 = w24 * inr55 */ + "stp q26, q27, [%[out0]], #32\n" /* store outr0, 2-3 */ + "fmla v30.4s, v22.4s, v4.4s \n" /* out12 = w24 * inr56 */ + "fmla v31.4s, v23.4s, v4.4s \n" /* out13 = w24 * inr57 */ + "ldr q24, [%[bias]] \n" /* load bias to out00 */ + "subs %w[cnt], %w[cnt], #1\n" /* cnt = cnt - 1 */ + "stp q28, q29, [%[out1]], #32\n" /* store outr1, 0-1 */ + "stp q30, q31, [%[out1]], #32\n" /* store outr1, 2-3 */ + "bne 1b\n" + : [cnt] "+r"(cnt), + [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [inr5] "+r"(inr5), + [wc] "+r"(wptr), + [out0] "+r"(ptr_out0), + [out1] "+r"(ptr_out1) + : [bias] "r"(bias_local) + : "cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25", + "v26","v27","v28","v29","v30","v31" + ); + // clang-format on + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr4; + block_inr3 = block_inr5; + block_inr4 = block_inr3 + in_len; + block_inr5 = block_inr4 + in_len; + } + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + ptr_write, + &act_param); } } } - free(din_new); } -#endif // __aarch64__ - -void conv_depthwise_5x5s1_fp32(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, +#else // __aarch64__ +void conv_depthwise_5x5s1_fp32(float* dout, + const float* din, const float* weights, const float* bias, - int pad, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + const operators::ConvParam& param, ARMContext* ctx) { - if (win < 4) { - if (flag_relu) { - conv_depthwise_5x5s1_small_relu_impl(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s1_small_impl(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); - } - } else { - if (flag_relu) { - conv_depthwise_5x5s1_relu_impl(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s1_impl(din, - dout, - num, + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + auto act_param = param.activation_param; + const int hout_c_block = 4; + const int hout_r_kernel = 1; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 4; + + //! get h block + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(float) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(float) + //! win_round = wout_round + 4 + //! hin_r_block = hout_r_block + 4 + int hout_r_block = (llc_size - 16 * win_round * hout_c_block * threads) / + (win_round * hout_c_block * threads * 4 + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? 
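The `hout_r_block` expression that starts here solves the cache-budget comment above it for the output-row block height H: with T threads, a channel block of C = 4, packed input width Wr = win_round and output width Wo = wout_round, the budget is llc = 4*T*C*Wr*(H + 4) + 4*T*C*Wo*H. A sketch of the closed form, clamped the same way the source clamps it:

    #include <algorithm>

    // Sketch: solve llc >= 4*T*C*Wr*(H+4) + 4*T*C*Wo*H for H (C == 4 here).
    inline int HoutRowBlock(int llc_size, int threads, int win_round,
                            int wout_round, int hout) {
      const int c = 4;
      int h = (llc_size - 16 * win_round * c * threads) /
              (4 * threads * c * (win_round + wout_round));
      return std::min(std::max(h, 1), hout);  // at least one row, at most hout
    }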
hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 4; + + float* tmp_work_space = ctx->workspace_data(); + float ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + float* tmp_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + + int out_row_stride = hout_c_block * wout_round; + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * chin * size_in_channel; + float* dout_batch = dout + n * chout * size_out_channel; + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel + 4; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size); + float* pre_out = pre_din + pre_in_size; +#else + float* pre_din = tmp_din; + float* pre_out = pre_din + pre_in_size; +#endif + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin, ptr_zero); + const float* block_inr0 = pre_din; + const float* block_inr1 = block_inr0 + in_len; + const float* block_inr2 = block_inr1 + in_len; + const float* block_inr3 = block_inr2 + in_len; + const float* block_inr4 = block_inr3 + in_len; + + const float* weight_c = weights + c * w_stride; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + + float* ptr_out0 = pre_out + hk * out_row_stride; + // clang-format off + auto wptr = weight_c; + asm volatile( + "vld1.32 {d24-d25}, [%[bias]] \n" /* load bias to out00 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w0-w1 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w2-w3 */ + "vld1.32 {d8-d11}, [%[inr0]]! \n" /* load inr0, 0-1 */ + "vld1.32 {d12-d15}, [%[inr0]]! \n" /* load inr0, 2-3 */ + "1:\n" + "vld1.32 {d16-d19}, [%[inr0]]! \n" /* load inr0, 4-5 */ + "vmov.u32 q13, q12 \n" /* mov bias to out01 */ + "vmov.u32 q14, q12 \n" /* mov bias to out02 */ + "vmov.u32 q15, q12 \n" /* mov bias to out03 */ + // out row0 + "vmla.f32 q12, q4, q0 \n" /* out00 = w0 * inr00 */ + "vmla.f32 q13, q5, q0 \n" /* out01 = w0 * inr01 */ + "vmla.f32 q14, q6, q0 \n" /* out02 = w0 * inr02 */ + "vmla.f32 q15, q7, q0 \n" /* out03 = w0 * inr03 */ + "vld1.32 {d20-d23}, [%[inr0]]! 
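Inside the channel loop above, every OpenMP thread carves a private [packed input | packed output] slice out of the shared context workspace, so the scratch buffers need no synchronization. A sketch of the slicing, guarded the way the source's ARM_WITH_OMP branch is:

    #ifdef _OPENMP
    #include <omp.h>
    #endif

    // Sketch: per-thread scratch = [pre_in_size floats | pre_out_size floats].
    inline float* ThreadScratch(float* base, int pre_in_size, int pre_out_size) {
    #ifdef _OPENMP
      const int tid = omp_get_thread_num();
    #else
      const int tid = 0;
    #endif
      return base + tid * (pre_in_size + pre_out_size);
    }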
\n" /* load inr0, 6-7 */ + "sub %[inr0], %[inr0], #64 \n" /* inr0 -= 64 */ + "vmla.f32 q12, q5, q1 \n" /* out00 = w1 * inr01 */ + "vmla.f32 q13, q6, q1 \n" /* out01 = w1 * inr02 */ + "vmla.f32 q14, q7, q1 \n" /* out02 = w1 * inr03 */ + "vmla.f32 q15, q8, q1 \n" /* out03 = w1 * inr04 */ + "vld1.32 {d8-d11}, [%[inr1]]!\n" /* load inr1, 0-1 */ + "vmla.f32 q12, q6, q2 \n" /* out00 = w2 * inr02 */ + "vmla.f32 q13, q7, q2 \n" /* out01 = w2 * inr03 */ + "vmla.f32 q14, q8, q2 \n" /* out02 = w2 * inr04 */ + "vmla.f32 q15, q9, q2 \n" /* out03 = w2 * inr05 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w4-w5 */ + "vmla.f32 q12, q7, q3 \n" /* out00 = w3 * inr03 */ + "vmla.f32 q13, q8, q3 \n" /* out01 = w3 * inr04 */ + "vmla.f32 q14, q9, q3 \n" /* out02 = w3 * inr05 */ + "vmla.f32 q15, q10, q3 \n" /* out03 = w3 * inr06 */ + "vld1.32 {d12-d15}, [%[inr1]]!\n" /* load inr1, 2-3 */ + "vmla.f32 q12, q8, q0 \n" /* out00 = w4 * inr04 */ + "vmla.f32 q13, q9, q0 \n" /* out01 = w4 * inr05 */ + "vmla.f32 q14, q10, q0 \n" /* out02 = w4 * inr06 */ + "vmla.f32 q15, q11, q0 \n" /* out03 = w4 * inr07 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w6-w7 */ + // out row1 + "vmla.f32 q12, q4, q1 \n" /* out00 = w5 * inr10 */ + "vmla.f32 q13, q5, q1 \n" /* out01 = w5 * inr11 */ + "vmla.f32 q14, q6, q1 \n" /* out02 = w5 * inr12 */ + "vmla.f32 q15, q7, q1 \n" /* out03 = w5 * inr13 */ + "vld1.32 {d16-d19}, [%[inr1]]!\n" /* load inr1, 4-5 */ + "vmla.f32 q12, q5, q2 \n" /* out00 = w6 * inr11 */ + "vmla.f32 q13, q6, q2 \n" /* out01 = w6 * inr12 */ + "vmla.f32 q14, q7, q2 \n" /* out02 = w6 * inr13 */ + "vmla.f32 q15, q8, q2 \n" /* out03 = w6 * inr14 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w8-w9 */ + "vmla.f32 q12, q6, q3 \n" /* out00 = w7 * inr12 */ + "vmla.f32 q13, q7, q3 \n" /* out01 = w7 * inr13 */ + "vld1.32 {d20-d23}, [%[inr1]]!\n" /* load inr1, 6-7 */ + "vmla.f32 q14, q8, q3 \n" /* out02 = w7 * inr14 */ + "vmla.f32 q15, q9, q3 \n" /* out03 = w7 * inr15 */ + "sub %[inr1], %[inr1], #64 \n" /* inr1 -= 64 */ + "vmla.f32 q12, q7, q0 \n" /* out00 = w8 * inr13 */ + "vmla.f32 q13, q8, q0 \n" /* out01 = w8 * inr14 */ + "vld1.32 {d8-d11}, [%[inr2]]!\n" /* load inr2, 0-1 */ + "vmla.f32 q14, q9, q0 \n" /* out02 = w8 * inr15 */ + "vmla.f32 q15, q10, q0 \n" /* out03 = w8 * inr16 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w10-w11 */ + "vmla.f32 q12, q8, q1 \n" /* out00 = w9 * inr14 */ + "vmla.f32 q13, q9, q1 \n" /* out01 = w9 * inr15 */ + "vld1.32 {d12-d15}, [%[inr2]]!\n" /* load inr2, 2-3 */ + "vmla.f32 q14, q10, q1 \n" /* out02 = w9 * inr16 */ + "vmla.f32 q15, q11, q1 \n" /* out03 = w9 * inr17 */ + // out row3 + "vmla.f32 q12, q4, q2 \n" /* out00 = w10 * inr20 */ + "vmla.f32 q13, q5, q2 \n" /* out01 = w10 * inr21 */ + "vld1.32 {d16-d19}, [%[inr2]]!\n" /* load inr2, 4-5 */ + "vmla.f32 q14, q6, q2 \n" /* out02 = w10 * inr22 */ + "vmla.f32 q15, q7, q2 \n" /* out03 = w10 * inr23 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w12-w13 */ + "vmla.f32 q12, q5, q3 \n" /* out00 = w11 * inr21 */ + "vmla.f32 q13, q6, q3 \n" /* out01 = w11 * inr22 */ + "vld1.32 {d20-d23}, [%[inr2]]!\n" /* load inr2, 6-7 */ + "vmla.f32 q14, q7, q3 \n" /* out02 = w11 * inr23 */ + "vmla.f32 q15, q8, q3 \n" /* out03 = w11 * inr24 */ + "vld1.32 {d4-d7}, [%[wc]]! 
\n" /* load w14-w15 */ + "sub %[inr2], %[inr2], #64 \n" /* inr2 -= 64 */ + "vmla.f32 q12, q6, q0 \n" /* out00 = w12 * inr22 */ + "vmla.f32 q13, q7, q0 \n" /* out01 = w12 * inr23 */ + "vmla.f32 q14, q8, q0 \n" /* out02 = w12 * inr24 */ + "vmla.f32 q15, q9, q0 \n" /* out03 = w12 * inr25 */ + "vld1.32 {d8-d11}, [%[inr3]]!\n" /* load inr3, 0-1 */ + "vmla.f32 q12, q7, q1 \n" /* out00 = w13 * inr23 */ + "vmla.f32 q13, q8, q1 \n" /* out01 = w13 * inr24 */ + "vmla.f32 q14, q9, q1 \n" /* out02 = w13 * inr25 */ + "vmla.f32 q15, q10, q1 \n" /* out03 = w13 * inr26 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w16-w17 */ + "vmla.f32 q12, q8, q2 \n" /* out00 = w14 * inr24 */ + "vmla.f32 q13, q9, q2 \n" /* out01 = w14 * inr25 */ + "vld1.32 {d12-d15}, [%[inr3]]!\n" /* load inr3, 2-3 */ + "vmla.f32 q14, q10, q2 \n" /* out02 = w14 * inr26 */ + "vmla.f32 q15, q11, q2 \n" /* out03 = w14 * inr27 */ + // out row3 + "vmla.f32 q12, q4, q3 \n" /* out00 = w15 * inr30 */ + "vmla.f32 q13, q5, q3 \n" /* out01 = w15 * inr31 */ + "vld1.32 {d16-d19}, [%[inr3]]!\n" /* load inr3, 4-5 */ + "vmla.f32 q14, q6, q3 \n" /* out02 = w15 * inr32 */ + "vmla.f32 q15, q7, q3 \n" /* out03 = w15 * inr33 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w18-w19 */ + "vmla.f32 q12, q5, q0 \n" /* out00 = w16 * inr31 */ + "vmla.f32 q13, q6, q0 \n" /* out01 = w16 * inr32 */ + "vld1.32 {d20-d23}, [%[inr3]]!\n" /* load inr3, 6-7 */ + "vmla.f32 q14, q7, q0 \n" /* out02 = w16 * inr33 */ + "vmla.f32 q15, q8, q0 \n" /* out03 = w16 * inr34 */ + "sub %[inr3], %[inr3], #64 \n" /* inr3 -= 64 */ + "vmla.f32 q12, q6, q1 \n" /* out00 = w17 * inr32 */ + "vmla.f32 q13, q7, q1 \n" /* out01 = w17 * inr33 */ + "vmla.f32 q14, q8, q1 \n" /* out02 = w17 * inr34 */ + "vmla.f32 q15, q9, q1 \n" /* out03 = w17 * inr35 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w20-w21 */ + "vmla.f32 q12, q7, q2 \n" /* out00 = w18 * inr33 */ + "vmla.f32 q13, q8, q2 \n" /* out01 = w18 * inr34 */ + "vmla.f32 q14, q9, q2 \n" /* out02 = w18 * inr35 */ + "vmla.f32 q15, q10, q2 \n" /* out03 = w18 * inr36 */ + "vld1.32 {d8-d11}, [%[inr4]]!\n" /* load inr4, 0-1 */ + "vmla.f32 q12, q8, q3 \n" /* out00 = w19 * inr34 */ + "vmla.f32 q13, q9, q3 \n" /* out01 = w19 * inr35 */ + "vld1.32 {d12-d15}, [%[inr4]]!\n" /* load inr4, 2-3 */ + "vmla.f32 q14, q10, q3 \n" /* out02 = w19 * inr36 */ + "vmla.f32 q15, q11, q3 \n" /* out03 = w19 * inr37 */ + // out row4 + "vmla.f32 q12, q4, q0 \n" /* out00 = w20 * inr40 */ + "vmla.f32 q13, q5, q0 \n" /* out01 = w20 * inr41 */ + "vld1.32 {d16-d19}, [%[inr4]]!\n" /* load inr4, 4-5 */ + "vmla.f32 q14, q6, q0 \n" /* out02 = w20 * inr42 */ + "vmla.f32 q15, q7, q0 \n" /* out03 = w20 * inr43 */ + "vld1.32 {d4-d7}, [%[wc]]! 
\n" /* load w22-w23 */ + "vmla.f32 q12, q5, q1 \n" /* out00 = w21 * inr41 */ + "vmla.f32 q13, q6, q1 \n" /* out01 = w21 * inr42 */ + "vmla.f32 q14, q7, q1 \n" /* out02 = w21 * inr43 */ + "vmla.f32 q15, q8, q1 \n" /* out03 = w21 * inr44 */ + "vld1.32 {d20-d23}, [%[inr4]]!\n" /* load inr4, 6-7 */ + "vmla.f32 q12, q6, q2 \n" /* out00 = w22 * inr42 */ + "vmla.f32 q13, q7, q2 \n" /* out01 = w22 * inr43 */ + "vmla.f32 q14, q8, q2 \n" /* out02 = w22 * inr44 */ + "vmla.f32 q15, q9, q2 \n" /* out03 = w22 * inr45 */ + "vld1.32 {d4-d5}, [%[wc]] \n" /* load w24 */ + "sub %[inr4], %[inr4], #64 \n" /* inr4 -= 64 */ + "vmla.f32 q12, q7, q3 \n" /* out00 = w23 * inr43 */ + "vmla.f32 q13, q8, q3 \n" /* out01 = w23 * inr44 */ + "vld1.32 {d8-d11}, [%[inr0]]!\n" /* load inr0, 0-1 */ + "sub %[wc], %[wc], #384 \n" /* wptr = wptr - 384 */ + "vmla.f32 q14, q9, q3 \n" /* out02 = w23 * inr45 */ + "vmla.f32 q15, q10, q3 \n" /* out03 = w23 * inr46 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w0-w1 */ + "vmla.f32 q12, q8, q2 \n" /* out00 = w24 * inr44 */ + "vmla.f32 q13, q9, q2 \n" /* out01 = w24 * inr45 */ + "vld1.32 {d12-d15}, [%[inr0]]!\n" /* load inr0, 2-3 */ + "vmla.f32 q14, q10, q2 \n" /* out02 = w24 * inr46 */ + "vmla.f32 q15, q11, q2 \n" /* out03 = w24 * inr47 */ + "vst1.32 {d24-d27}, [%[out0]]!\n" /* store out00, out01 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w2-w3 */ + "subs %[cnt], %[cnt], #1 \n" /* cnt = cnt - 1 */ + "vst1.32 {d28-d31}, [%[out0]]!\n" /* store out02, out03 */ + "vld1.32 {d24-d25}, [%[bias]] \n" /* load bias to out00 */ + "bne 1b\n" + : [cnt] "+r"(cnt), + [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc] "+r"(wptr), + [out0] "+r"(ptr_out0) + : [bias] "r"(bias_local) + : "cc","memory", + "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15" + ); + // clang-format on + block_inr0 = block_inr1; + block_inr1 = block_inr2; + block_inr2 = block_inr3; + block_inr3 = block_inr4; + block_inr4 = block_inr3 + in_len; + } + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, chout, hout, wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, flag_relu, - ctx); + ptr_write, + &act_param); + } } } } - +#endif // __aarch64__ } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc index 802082048c86beeeecfe64a0de09880b1b9b0137..ed3dad300804dc90fac874999ac5d0a420cff4a4 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc @@ -709,7 +709,6 @@ void conv_depthwise_5x5s1_int8(Dtype* dout, "q15"); #endif // clang-format on - int32_t* ptr_tmp = ptr_out0 - w_loop * 32; block_inr0 = block_inr1; block_inr1 = block_inr2; block_inr2 = block_inr3; diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index dced24db72f71630c0cb9d7ff4275f740a2b69a4..a72b7553e0c8fddcb9028b0e6125281a07e65387 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -13,3732 +13,932 @@ // limitations under the License. 
#include +#include "lite/backends/arm/math/conv_block_utils.h" #include "lite/backends/arm/math/conv_depthwise.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif namespace paddle { namespace lite { namespace arm { namespace math { +#ifdef __aarch64__ +#define COMPUTE \ + "ldp q0, q1, [%[inr0]], #32\n" /* load r0, 0-1 */ \ + "and v19.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q2, q3, [%[inr0]], #32\n" /* load r0, 2-3 */ \ + "and v20.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q4, q5, [%[inr0]], #32\n" /* load r0, 4-5 */ \ + "and v21.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q6, q7, [%[inr0]], #32\n" /* load r0, 6-7 */ \ + "and v22.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q8, q9, [%[inr0]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldr q10, [%[inr0]] \n" /* load r0, 10 */ \ + "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "sub %[inr0], %[inr0], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr1]], #32\n" /* load r1, 0-1 */ \ + "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , %[w3].4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w3].4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w3].4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w3].4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr1]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , %[w4].4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w4].4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w4].4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w4].4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr1]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr1]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr1]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr1], %[inr1], #32\n" /* inr1 -= 32 */ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 
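The s2 COMPUTE macro introduced here produces a 1x4 strip of 4-channel outputs per row of taps: at stride 2, output x reads packed vectors 2x .. 2x+4, so the strip touches input positions 0..10 (hence the `ldp` pairs up to q9 plus the final lone `ldr q10`). A minimal intrinsics sketch of one tap row, assuming the same NC4HW4 packing (`vfmaq_f32` again being the a64 form):

    #include <arm_neon.h>

    // Sketch: out[x] += sum_k w[k] * in[2*x + k], k = 0..4, for x = 0..3.
    inline void Row5TapsS2C4(const float* in /* >= 11 packed vectors */,
                             const float32x4_t w[5], float32x4_t out[4]) {
      for (int x = 0; x < 4; ++x) {
        float32x4_t acc = out[x];
        for (int k = 0; k < 5; ++k)
          acc = vfmaq_f32(acc, vld1q_f32(in + 4 * (2 * x + k)), w[k]);
        out[x] = acc;
      }
    }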
= w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr2]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr2]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr2]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr2]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr2], %[inr2], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr3]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr3]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr3]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr3]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr3]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr3]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr3], %[inr3], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr4]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * 
r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr4]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr4]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr4]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr4]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr4]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr4], %[inr4], #32\n" /* inr0 -= 32 */ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "sub %[wc0], %[wc0], #320\n" /* weight -= 320 */ \ + "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ +#define RELU /* relu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" +#define RELU6 /* relu6 */ \ + "fmin v19.4s, v19.4s, %[vsix].4s\n" \ + "fmin v20.4s, v20.4s, %[vsix].4s\n" \ + "fmin v21.4s, v21.4s, %[vsix].4s\n" \ + "fmin v22.4s, v22.4s, %[vsix].4s\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ + "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ + "bif v20.16b, v4.16b, v3.16b \n" /* choose*/ \ + "bif v21.16b, v6.16b, v5.16b \n" /* choose*/ \ + "bif v22.16b, v8.16b, v7.16b \n" /* choose*/ +#define 
STORE /* save result */ \ + "str q19, [%[outc0]], #16\n" \ + "str q20, [%[outc1]], #16\n" \ + "str q21, [%[outc2]], #16\n" \ + "str q22, [%[outc3]], #16\n" -void conv_depthwise_5x5s2p2(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2p2_relu(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2p2_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2p2_relu_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2_fp32(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, +#else +#define COMPUTE \ + /* fill with bias */ \ + "vld1.32 {d12-d13}, [%[bias]]\n" /* load bias */ /* load weights */ \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \ + "vand.i32 q12, q6, q6\n" \ + "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \ + "vand.i32 q13, q6, q6\n" \ + "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \ + "vand.i32 q14, q6, q6\n" \ + "vand.i32 q15, q6, q6\n" \ + "vld1.32 {d12-d13}, [%[r0]]!\n" /* load input r0, 6*/ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-q10 */ \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr6\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr4\n" \ + "vld1.32 {d0-d3}, [%[r0]]! \n" /* load r0, 7-8 */ \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vld1.32 {d4-d7}, [%[r0]] \n" /* load r0, 9-10 */ \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d4-d5}, [%[r1]]! @ load r1, 2\n" \ + "sub %[r0], %[r0], #16 @ r0 - 16 to nextline address\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 3, 4\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vld1.32 {d10-d13}, [%[r1]]! 
@ load r1, 5, 6\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr0\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr2\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 7, 8\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.32 {d4-d7}, [%[r1]] @ load r1, 9, 10\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "sub %[r1], %[r1], #16 @ r1 - 16 to nextline address\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.32 {d12-d13}, [%[r2]]! @ load r2, 6 \n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d4-d7}, [%[r2]] @ load r2, 9, 10\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[r2], %[r2], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r3]]! @ load r3, 0, 1\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, 2, 3\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r3]]! @ load r3, 4, 5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.32 {d12-d13}, [%[r3]]! @ load r3, 6, \n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.32 {d4-d7}, [%[r3]] @ load r3, 9, 10\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[r3], %[r3], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r4]]! @ load r4, 0, 1\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d4-d7}, [%[r4]]! @ load r4, 2, 3\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r4]]! @ load r3, 4, 5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.32 {d12-d13}, [%[r4]]! @ load r3, 6, \n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r4]]! @ load r3, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.32 {d4-d7}, [%[r4]] @ load r3, 9, 10\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "sub %[wc0], %[wc0], #400 @ wc0 - 400 to start address\n" \ + "sub %[r4], %[r4], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \ + "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \ + "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \ + "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ + +#define RELU /* relu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[six_ptr]]\n" \ + "vmax.f32 q12, q12, q0\n" \ + "vmax.f32 q13, q13, q0\n" \ + "vmax.f32 q14, q14, q0\n" \ + "vmax.f32 q15, q15, q0\n" +#define RELU6 /* relu6 */ \ + "vmin.f32 q12, q12, q1\n" \ + "vmin.f32 q13, q13, q1\n" \ + "vmin.f32 q14, q14, q1\n" \ + "vmin.f32 q15, q15, q1\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[scale_ptr]]\n" \ + "vcge.f32 q2, q12, q0 @ q0 > 0 \n" \ + "vcge.f32 q4, q13, q0 @ q0 > 0 \n" \ + "vcge.f32 q6, q14, q0 @ q0 > 0 \n" \ + "vcge.f32 q8, q15, q0 @ q0 > 0 \n" \ + "vmul.f32 q3, q12, q1 @ mul \n" \ + "vmul.f32 q5, q13, q1 @ mul \n" \ + "vmul.f32 q7, q14, q1 @ mul \n" \ + "vmul.f32 q9, q15, q1 @ mul \n" \ + "vbif q12, q3, q2 @ choose \n" \ + "vbif q13, q5, q4 @ choose \n" \ + "vbif q14, q7, q6 @ choose \n" \ + "vbif q15, q9, q8 @ choose \n" +#define STORE /* save result */ \ + "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \ + "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \ + "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \ + "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ + +#endif + +void act_switch_5x5s2(const float* inr0, + const 
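Two things worth spelling out about the macro blocks that end here. First, COMPUTE leaves each accumulator holding one output pixel with the 4 channels in lanes; the trn1/trn2 sequence (vtrn/vswp on armv7) transposes that 4x4 so each register holds four pixels of one channel and stores straight through outc0..outc3. Second, LEAKY_RELU is a compare-multiply-select, and `act_switch_5x5s2` below picks a fused epilogue by concatenating adjacent string literals (e.g. COMPUTE RELU STORE) into one asm body at compile time. Sketches of both in portable intrinsics (`vbslq_f32` keeps its second operand where the mask is set, matching the bif/vbif usage):

    #include <arm_neon.h>

    // Sketch of the accumulator transpose done by trn1/trn2 (vtrn/vswp on armv7).
    inline void Transpose4x4(float32x4_t io[4]) {
      float32x4x2_t t01 = vtrnq_f32(io[0], io[1]);
      float32x4x2_t t23 = vtrnq_f32(io[2], io[3]);
      io[0] = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0]));
      io[1] = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1]));
      io[2] = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0]));
      io[3] = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1]));
    }

    // Sketch of LEAKY_RELU: keep v where v >= 0, else v * scale.
    inline float32x4_t LeakyRelu(float32x4_t v, float32x4_t vscale) {
      uint32x4_t nonneg = vcgeq_f32(v, vdupq_n_f32(0.f));
      return vbslq_f32(nonneg, v, vmulq_f32(v, vscale));
    }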
float* inr1, + const float* inr2, + const float* inr3, + const float* inr4, + float* outc0, + float* outc1, + float* outc2, + float* outc3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t vbias, + const float* weight_c, + float* bias_local, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(tmp); + float32x4_t vscale = vdupq_n_f32(ss); +#else + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias), + [vsix] "w"(vsix) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias), + [vscale] "w"(vscale) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + 
"v21", + "v22"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} +void conv_depthwise_5x5s2_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, int win, const float* weights, const float* bias, - int pad, - bool flag_bias, - bool flag_relu, + const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx) { - if (pad == 2) { - if (win >= 9) { - if (flag_relu) { - conv_depthwise_5x5s2p2_relu(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s2p2(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); - } - } else { - if (flag_relu) { - conv_depthwise_5x5s2p2_relu_s(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s2p2_s(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); + auto paddings = *param.paddings; + int threads = ctx->threads(); + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + const int out_c_block = 4; + const int out_h_kernel = 1; + const int out_w_kernel = 4; + const int win_ext = ow * 2 + 3; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh * 2 + 3; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = threads * prein_size + win_round + ow_round; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + auto ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + 
win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 25; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; } - } - } -} - #ifdef __aarch64__ - -//! larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - const float* din5 = din4 + w_in; - const float* din6 = din5 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 2) { - //! 
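// The removed kernels below handle the right edge with a 12-lane bit mask:
// memset(mask, 0xff, ...) sets every lane to all-ones, the last mask_cnt
// lanes are zeroed, and the asm `bif` (bitwise insert if false) blends
// out-of-range input lanes with zero before they reach an fmla. A minimal
// intrinsics sketch of that predication, assuming <arm_neon.h> and 32-bit
// unsigned mask lanes:
//   #include <arm_neon.h>
//   inline float32x4_t masked_load_f32(const float* p, const uint32_t* mask) {
//     float32x4_t v = vld1q_f32(p);    // may pick up lanes past the row end
//     uint32x4_t m = vld1q_u32(mask);  // 0xffffffff keeps a lane, 0 zeroes it
//     return vbslq_f32(m, v, vdupq_n_f32(0.f));  // keep v where mask is set
//   }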
(h * 2 - 2) + 6 > h_in - 1 - if (h * 2 + 5 > h_in) { - switch (h * 2 + 5 - h_in) { - case 6: - din1 = zero_ptr; - case 5: - din2 = zero_ptr; - case 4: - din3 = zero_ptr; - case 3: - din4 = zero_ptr; - case 2: - din5 = zero_ptr; - case 1: - din6 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out) { - switch (h + 2 - h_out) { - case 1: - dout1 = write_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - //! in r0, r1/r4, r2/r5, r3/r6: x 0 2 4 -- v8 v13 v18 v23 - //! in r0, r1/r4, r2/r5, r3/r6: x 1 3 5 -- v9 v14 v19 v24 - //! in r0, r1/r4, r2/r5, r3/r6: 0 2 4 6 -- v6 v11 v16 v21 - //! in r0, r1/r4, r2/r5, r3/r6: 1 3 5 7 -- v7 v12 v17 v22 - //! in r0, r1/r4, r2/r5, r3/r6: 2 4 6 8 -- v10 v15 v20 v25 - //! out r0, r1 -- v26, v27 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - // left - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" // r0 v6: 0 - // 2 4 6, - // v7: 1 3 - // 5 7 - "ext v8.16b, v31.16b, v6.16b, #12 \n" // r0 v8: x - // 0 2 4 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" // r1 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - "ext v9.16b, v31.16b, v7.16b, #12 \n" // r0 v9: x - // 1 3 5 - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load - // weights - // 0-7 - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ld1 {v10.s}[3], [%[din_ptr0]] \n" // r0 v10: - // 2 4 6 8 - "sub %[din_ptr0], %[din_ptr0], #8 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r1 v13: - // x 0 2 4 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" // r2 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r1 v14: - // x 1 3 5 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load - // weights - // 8-15 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" // r1 v15: - // 2 4 6 - "sub %[din_ptr1], %[din_ptr1], #8 \n" - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r2 v18: - // x 0 2 4 - "ld1 {v4.4s, v5.4s}, [%[weights]], #32 \n" // load - // weights - // 16-23 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r2 v19: - // x 1 3 5 - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" // r3 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" // r2 v20: - // 2 4 6 8 - "sub %[din_ptr2], %[din_ptr2], #8 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r3 v23: - // x 0 2 4 - "ld1 {v30.4s}, [%[weights]] \n" // load - // weights - // 24 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r3 v24: - // x 1 3 5 - "ld1 {v26.4s}, [%[vbias]] \n" // load - // bias to - // out_r0 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" // r2 v25: - // 2 4 6 8 - "sub %[din_ptr3], %[din_ptr3], #8 \n" - "mov v27.16b, v26.16b \n" // load - // bias to - // out_r1 - "mov v28.16b, v31.16b \n" // 
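// Stride-2 vectorization trick used throughout this asm: `ld2` de-interleaves
// a row into even columns (0 2 4 6) and odd columns (1 3 5 7), so one fmla
// consumes four stride-2 taps at a time. A minimal intrinsics sketch of the
// same split, assuming <arm_neon.h>:
//   #include <arm_neon.h>
//   inline void load_even_odd(const float* row,
//                             float32x4_t* even,
//                             float32x4_t* odd) {
//     float32x4x2_t v = vld2q_f32(row);  // val[0] = row[0,2,4,6],
//                                        // val[1] = row[1,3,5,7]
//     *even = v.val[0];
//     *odd = v.val[1];
//   }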
load - // zero to - // out_r0 - "mov v29.16b, v31.16b \n" // load - // zero to - // out_r1 - - "fmla v26.4s, v8.4s, v0.s[0] \n" // out r0: - // w0 - "fmla v28.4s, v9.4s, v0.s[1] \n" // out r0: - // w1 - "fmla v26.4s, v6.4s, v0.s[2] \n" // out r0: - // w2 - "fmla v28.4s, v7.4s, v0.s[3] \n" // out r0: - // w3 - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v8: 0 2 - // 4 6, v9: - // 1 3 5 7 - - "fmla v26.4s, v10.4s, v1.s[0] \n" // out r0: - // w4 - "fmla v28.4s, v13.4s, v1.s[1] \n" // out r0: - // w5 - "fmla v26.4s, v14.4s, v1.s[2] \n" // out r0: - // w6 - "fmla v28.4s, v11.4s, v1.s[3] \n" // out r0: - // w7 - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v6: 2 4 - // 6 8, v7: - // 3 5 7 9 - - "fmla v26.4s, v12.4s, v2.s[0] \n" // out r0: - // w8 - "fmla v28.4s, v15.4s, v2.s[1] \n" // out r0: - // w9 - "fmla v26.4s, v18.4s, v2.s[2] \n" // out r0: - // w10 - "fmla v28.4s, v19.4s, v2.s[3] \n" // out r0: - // w11 - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" // next r0 - // v10: 4 6 - // 8 10, - // v11: - // trash - // register - - "fmla v26.4s, v16.4s, v3.s[0] \n" // out r0: - // w12 - "fmla v28.4s, v17.4s, v3.s[1] \n" // out r0: - // w13 - "fmla v26.4s, v20.4s, v3.s[2] \n" // out r0: - // w14 - "fmla v28.4s, v23.4s, v3.s[3] \n" // out r0: - // w15 - "prfm pldl1keep, [%[din_ptr0]] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], #32 \n" // r4 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - - "fmla v26.4s, v24.4s, v4.s[0] \n" // out r0: - // w16 - "fmla v28.4s, v21.4s, v4.s[1] \n" // out r0: - // w17 - - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r4 v13: - // x 0 2 4 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r4 v14: - // x 1 3 5 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - - "fmla v26.4s, v22.4s, v4.s[2] \n" // out r0: - // w18 - "fmla v28.4s, v25.4s, v4.s[3] \n" // out r0: - // w19 - - "ld1 {v15.s}[3], [%[din_ptr4]] \n" // r4 v15: - // 2 4 6 - - "fmla v27.4s, v18.4s, v0.s[0] \n" // out r1: - // w0 - "fmla v29.4s, v19.4s, v0.s[1] \n" // out r1: - // w1 - - "sub %[din_ptr4], %[din_ptr4], #8 \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" // out r1: - // w2 - "fmla v29.4s, v17.4s, v0.s[3] \n" // out r1: - // w3 - "fmla v27.4s, v20.4s, v1.s[0] \n" // out r1: - // w4 - "fmla v29.4s, v23.4s, v1.s[1] \n" // out r1: - // w5 - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], #32 \n" // r5 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - - "fmla v27.4s, v24.4s, v1.s[2] \n" // out r1: - // w6 - "fmla v29.4s, v21.4s, v1.s[3] \n" // out r1: - // w7 - - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r5 v18: - // x 0 2 4 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r5 v19: - // x 1 3 5 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" // out r1: - // w8 - "fmla v29.4s, v25.4s, v2.s[1] \n" // out r1: - // w9 - - "ld1 {v20.s}[3], [%[din_ptr5]] \n" // r5 v20: - // 2 4 6 - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], #32 \n" // r6 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r6 v23: - // x 0 2 4 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r6 v24: - // x 1 3 5 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "sub %[din_ptr5], %[din_ptr5], #8 \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" // out r0: - // w22 - "fmla v28.4s, v12.4s, v5.s[3] \n" // out r0: - // w23 - - "ld1 {v25.s}[3], [%[din_ptr6]] \n" // r6 v25: - // 2 4 6 - - "fmla v26.4s, v13.4s, v5.s[0] \n" // out r0: - // w20 - "fmla v28.4s, v14.4s, v5.s[1] \n" // out r0: - // w21 - - "sub %[din_ptr6], %[din_ptr6], #8 \n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" // out r0: - // w24 - "fmla 
v27.4s, v13.4s, v2.s[2] \n" // out r1: - // w10 - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" // out r1: - // w11 - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v13: 0 2 - // 4 6, - // v14: 1 3 - // 5 7 - "fmla v27.4s, v11.4s, v3.s[0] \n" // out r1: - // w12 - "fmla v29.4s, v12.4s, v3.s[1] \n" // out r1: - // w13 - - "st1 {v26.4s}, [%[dout_ptr0]], %[s_16] \n" // store - // output - // r0 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v11: 2 4 - // 6 8, - // v12: 3 5 - // 7 9 - - "fmla v27.4s, v15.4s, v3.s[2] \n" // out r1: - // w14 - "fmla v29.4s, v16.4s, v4.s[1] \n" // out r1: - // w17 - "fmla v27.4s, v18.4s, v3.s[3] \n" // out r1: - // w15 - "fmla v29.4s, v19.4s, v4.s[0] \n" // out r1: - // w16 - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" // next r1 - // v15: 4 6 - // 8 10, - // v16: - // trash - // register - - "fmla v27.4s, v17.4s, v4.s[2] \n" // out r1: - // w18 - "fmla v29.4s, v20.4s, v4.s[3] \n" // out r1: - // w19 - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v18: 0 2 - // 4 6, - // v19: 1 3 - // 5 7 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v16: 2 4 - // 6 8, - // v11: 3 5 - // 7 9 - - "fmla v27.4s, v23.4s, v5.s[0] \n" // out r1: - // w20 - "fmla v29.4s, v21.4s, v5.s[2] \n" // out r1: - // w22 - "fmla v27.4s, v24.4s, v5.s[1] \n" // out r1: - // w21 - "fmla v29.4s, v22.4s, v5.s[3] \n" // out r1: - // w23 - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" // next r2 - // v20: 4 6 - // 8 10, - // v21: - // trash - // register - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v23: 0 2 - // 4 6, - // v24: 1 3 - // 5 7 - - "fmla v27.4s, v25.4s, v30.s[0] \n" // out r1: - // w24 - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v21: 2 4 - // 6 8, - // v22: 3 5 - // 7 9 - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" // next r3 - // v25: 4 6 - // 8 10, - // v26: - // trash - // register - - "fadd v27.4s, v27.4s, v29.4s \n" - "cmp %w[mid_cnt], #1 \n" - - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - "blt 2f \n" - - // mid loop - "1: \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v27.16b, v26.16b \n" - "mov v28.16b, v31.16b \n" - "mov v29.16b, v31.16b \n" - - // out_r0 r0-r3 - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr0]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v27.4s, v16.4s, v0.s[2] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - - "fmla v28.4s, v17.4s, v3.s[1] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] 
\n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - - "fmla v26.4s, v24.4s, v4.s[0] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - "fadd v28.4s, v26.4s, v28.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]], %[s_16] \n" - "mov v26.16b, v31.16b \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v29.4s, v20.4s, v4.s[3] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v27.4s, v21.4s, v5.s[2] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" - - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - - "fmla v29.4s, v22.4s, v5.s[3] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" - - "fmla v27.4s, v25.4s, v30.s[0] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - - "prfm pldl1keep, [%[din_ptr3]] \n" - - "fadd v27.4s, v27.4s, v29.4s \n" - - "st1 {v26.4s}, [%[dout_ptr0]], #16 \n" - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" - "subs %w[mid_cnt], %w[mid_cnt], #1 \n" - "bne 1b \n" - - "2: \n" - "ld2 {v26.4s, v27.4s}, [%[mask]], %[s_8] \n" - "ld2 {v28.4s, v29.4s}, [%[mask]], %[s_8] \n" - "bif v8.16b, v31.16b, v26.16b \n" - "bif v9.16b, v31.16b, v27.16b \n" - "bif v6.16b, v31.16b, v28.16b \n" - "bif v7.16b, v31.16b, v29.16b \n" - - "bif v13.16b, v31.16b, v26.16b \n" - "bif v14.16b, v31.16b, v27.16b \n" - "bif v11.16b, v31.16b, v28.16b \n" - "bif v12.16b, v31.16b, v29.16b \n" - - "bif v18.16b, v31.16b, v26.16b \n" - "bif v19.16b, v31.16b, v27.16b \n" - "bif v16.16b, v31.16b, v28.16b \n" - "bif v17.16b, v31.16b, v29.16b \n" - - "bif v23.16b, v31.16b, v26.16b \n" - "bif v24.16b, v31.16b, v27.16b \n" - "bif v21.16b, v31.16b, v28.16b \n" - "bif v22.16b, v31.16b, v29.16b \n" - - "ld2 {v28.4s, v29.4s}, [%[mask]] \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v29.16b, v31.16b \n" - - "bif v10.16b, v31.16b, v28.16b \n" - "bif v15.16b, v31.16b, v28.16b \n" - - "mov v27.16b, v26.16b \n" - - "bif v20.16b, v31.16b, v28.16b \n" - "bif v25.16b, v31.16b, v28.16b \n" - "mov v28.16b, v31.16b \n" - - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - 
"fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "sub %[mask], %[mask], #16 \n" - "ld2 {v6.4s, v7.4s}, [%[mask]], %[s_8] \n" - "ld2 {v8.4s, v9.4s}, [%[mask]], %[s_8] \n" - "ld2 {v10.4s, v11.4s}, [%[mask]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v28.4s, v17.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "bif v13.16b, v31.16b, v6.16b \n" - "bif v14.16b, v31.16b, v7.16b \n" - "bif v11.16b, v31.16b, v8.16b \n" - "bif v12.16b, v31.16b, v9.16b \n" - "bif v15.16b, v31.16b, v10.16b \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]] \n" - - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - "fmla v26.4s, v24.4s, v4.s[0] \n" - - "bif v18.16b, v31.16b, v6.16b \n" - "bif v19.16b, v31.16b, v7.16b \n" - "bif v16.16b, v31.16b, v8.16b \n" - "bif v17.16b, v31.16b, v9.16b \n" - "bif v20.16b, v31.16b, v10.16b \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - "fadd v28.4s, v28.4s, v26.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]] \n" - "mov v26.16b, v31.16b \n" - - "bif v23.16b, v31.16b, v6.16b \n" - "bif v24.16b, v31.16b, v7.16b \n" - "bif v21.16b, v31.16b, v8.16b \n" - "bif v22.16b, v31.16b, v9.16b \n" - "bif v25.16b, v31.16b, v10.16b \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v26.4s, v15.4s, v30.s[0] \n" - - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - - "st1 {v26.4s}, [%[out_buf0]] \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - "fmla v29.4s, v20.4s, v4.s[3] \n" - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "fmla v27.4s, v21.4s, v5.s[2] \n" - "fmla v29.4s, v22.4s, v5.s[3] \n" - "fmla v27.4s, v25.4s, v30.s[0] \n" - "fadd v27.4s, v27.4s, v29.4s \n" - - "st1 {v27.4s}, [%[out_buf1]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [dout_ptr1] "+r"(dout_ptr1), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [din_ptr6] "+r"(din_ptr6), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - 
"v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - dout_ptr1[i] = out_buf1[i]; - } - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din6 + w_in; - din4 = din3 + w_in; - din5 = din4 + w_in; - din6 = din5 + w_in; - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t vbias = vdupq_n_f32(0.f); + if (flag_bias) { + vbias = vld1q_f32(&bias[c]); // v28 } - } - } -} - -//! larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2_relu(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; - -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - const float* din5 = din4 + w_in; - const float* din6 = din5 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 2) { - //! 
(h * 2 - 2) + 6 > h_in - 1 - if (h * 2 + 5 > h_in) { - switch (h * 2 + 5 - h_in) { - case 6: - din1 = zero_ptr; - case 5: - din2 = zero_ptr; - case 4: - din3 = zero_ptr; + weight_c += 20; +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc0 = dout_c00 + h * ow; + float* outc1 = outc0 + size_out_channel; + float* outc2 = outc1 + size_out_channel; + float* outc3 = outc2 + size_out_channel; + const float* inr0 = pre_din + h * 2 * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + const float* inr4 = inr3 + row_len; + + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { case 3: - din4 = zero_ptr; + outc1 = ptr_write; case 2: - din5 = zero_ptr; + outc2 = ptr_write; case 1: - din6 = zero_ptr; + outc3 = ptr_write; default: break; } } - if (h + 2 > h_out) { - switch (h + 2 - h_out) { - case 1: - dout1 = write_ptr; - default: - break; + auto c0 = outc0; + auto c1 = outc1; + auto c2 = outc2; + auto c3 = outc3; + float pre_out[16]; + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + if (flag_mask) { + c0 = outc0; + c1 = outc1; + c2 = outc2; + c3 = outc3; + outc0 = pre_out; + outc1 = pre_out + 4; + outc2 = pre_out + 8; + outc3 = pre_out + 12; } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - //! in r0, r1/r4, r2/r5, r3/r6: x 0 2 4 -- v8 v13 v18 v23 - //! in r0, r1/r4, r2/r5, r3/r6: x 1 3 5 -- v9 v14 v19 v24 - //! in r0, r1/r4, r2/r5, r3/r6: 0 2 4 6 -- v6 v11 v16 v21 - //! in r0, r1/r4, r2/r5, r3/r6: 1 3 5 7 -- v7 v12 v17 v22 - //! in r0, r1/r4, r2/r5, r3/r6: 2 4 6 8 -- v10 v15 v20 v25 - //! 
out r0, r1 -- v26, v27 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - // left - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" // r0 v6: 0 - // 2 4 6, - // v7: 1 3 - // 5 7 - "ext v8.16b, v31.16b, v6.16b, #12 \n" // r0 v8: x - // 0 2 4 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" // r1 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - "ext v9.16b, v31.16b, v7.16b, #12 \n" // r0 v9: x - // 1 3 5 - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load - // weights - // 0-7 - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ld1 {v10.s}[3], [%[din_ptr0]] \n" // r0 v10: - // 2 4 6 8 - "sub %[din_ptr0], %[din_ptr0], #8 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r1 v13: - // x 0 2 4 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" // r2 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r1 v14: - // x 1 3 5 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load - // weights - // 8-15 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" // r1 v15: - // 2 4 6 - "sub %[din_ptr1], %[din_ptr1], #8 \n" - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r2 v18: - // x 0 2 4 - "ld1 {v4.4s, v5.4s}, [%[weights]], #32 \n" // load - // weights - // 16-23 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r2 v19: - // x 1 3 5 - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" // r3 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" // r2 v20: - // 2 4 6 8 - "sub %[din_ptr2], %[din_ptr2], #8 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r3 v23: - // x 0 2 4 - "ld1 {v30.4s}, [%[weights]] \n" // load - // weights - // 24 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r3 v24: - // x 1 3 5 - "ld1 {v26.4s}, [%[vbias]] \n" // load - // bias to - // out_r0 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" // r2 v25: - // 2 4 6 8 - "sub %[din_ptr3], %[din_ptr3], #8 \n" - "mov v27.16b, v26.16b \n" // load - // bias to - // out_r1 - "mov v28.16b, v31.16b \n" // load - // zero to - // out_r0 - "mov v29.16b, v31.16b \n" // load - // zero to - // out_r1 - - "fmla v26.4s, v8.4s, v0.s[0] \n" // out r0: - // w0 - "fmla v28.4s, v9.4s, v0.s[1] \n" // out r0: - // w1 - "fmla v26.4s, v6.4s, v0.s[2] \n" // out r0: - // w2 - "fmla v28.4s, v7.4s, v0.s[3] \n" // out r0: - // w3 - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v8: 0 2 - // 4 6, v9: - // 1 3 5 7 - - "fmla v26.4s, v10.4s, v1.s[0] \n" // out r0: - // w4 - "fmla v28.4s, v13.4s, v1.s[1] \n" // out r0: - // w5 - "fmla v26.4s, v14.4s, v1.s[2] \n" // out r0: - // w6 - "fmla v28.4s, v11.4s, v1.s[3] \n" // out r0: - // w7 - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v6: 2 4 - // 6 8, v7: - // 3 5 7 9 - - "fmla v26.4s, v12.4s, v2.s[0] \n" // out r0: - // w8 - "fmla v28.4s, v15.4s, v2.s[1] \n" // out r0: - // w9 - "fmla v26.4s, v18.4s, v2.s[2] \n" // out r0: - // w10 - "fmla v28.4s, v19.4s, v2.s[3] \n" // out r0: - // w11 - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" // next r0 - // v10: 4 6 - // 8 10, - // v11: - // trash - // register - - "fmla v26.4s, v16.4s, v3.s[0] \n" // out r0: - // w12 - "fmla v28.4s, v17.4s, v3.s[1] \n" // out r0: - // w13 - "fmla v26.4s, v20.4s, v3.s[2] \n" // 
out r0: - // w14 - "fmla v28.4s, v23.4s, v3.s[3] \n" // out r0: - // w15 - "prfm pldl1keep, [%[din_ptr0]] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], #32 \n" // r4 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - - "fmla v26.4s, v24.4s, v4.s[0] \n" // out r0: - // w16 - "fmla v28.4s, v21.4s, v4.s[1] \n" // out r0: - // w17 - - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r4 v13: - // x 0 2 4 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r4 v14: - // x 1 3 5 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - - "fmla v26.4s, v22.4s, v4.s[2] \n" // out r0: - // w18 - "fmla v28.4s, v25.4s, v4.s[3] \n" // out r0: - // w19 - - "ld1 {v15.s}[3], [%[din_ptr4]] \n" // r4 v15: - // 2 4 6 - - "fmla v27.4s, v18.4s, v0.s[0] \n" // out r1: - // w0 - "fmla v29.4s, v19.4s, v0.s[1] \n" // out r1: - // w1 - - "sub %[din_ptr4], %[din_ptr4], #8 \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" // out r1: - // w2 - "fmla v29.4s, v17.4s, v0.s[3] \n" // out r1: - // w3 - "fmla v27.4s, v20.4s, v1.s[0] \n" // out r1: - // w4 - "fmla v29.4s, v23.4s, v1.s[1] \n" // out r1: - // w5 - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], #32 \n" // r5 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - - "fmla v27.4s, v24.4s, v1.s[2] \n" // out r1: - // w6 - "fmla v29.4s, v21.4s, v1.s[3] \n" // out r1: - // w7 - - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r5 v18: - // x 0 2 4 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r5 v19: - // x 1 3 5 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" // out r1: - // w8 - "fmla v29.4s, v25.4s, v2.s[1] \n" // out r1: - // w9 - - "ld1 {v20.s}[3], [%[din_ptr5]] \n" // r5 v20: - // 2 4 6 - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], #32 \n" // r6 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r6 v23: - // x 0 2 4 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r6 v24: - // x 1 3 5 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "sub %[din_ptr5], %[din_ptr5], #8 \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" // out r0: - // w22 - "fmla v28.4s, v12.4s, v5.s[3] \n" // out r0: - // w23 - - "ld1 {v25.s}[3], [%[din_ptr6]] \n" // r6 v25: - // 2 4 6 - - "fmla v26.4s, v13.4s, v5.s[0] \n" // out r0: - // w20 - "fmla v28.4s, v14.4s, v5.s[1] \n" // out r0: - // w21 - - "sub %[din_ptr6], %[din_ptr6], #8 \n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" // out r0: - // w24 - "fmla v27.4s, v13.4s, v2.s[2] \n" // out r1: - // w10 - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" // out r1: - // w11 - "fmax v26.4s, v26.4s, v31.4s \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v13: 0 2 - // 4 6, - // v14: 1 3 - // 5 7 - "fmla v27.4s, v11.4s, v3.s[0] \n" // out r1: - // w12 - "fmla v29.4s, v12.4s, v3.s[1] \n" // out r1: - // w13 - - "st1 {v26.4s}, [%[dout_ptr0]], %[s_16] \n" // store - // output - // r0 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v11: 2 4 - // 6 8, - // v12: 3 5 - // 7 9 - - "fmla v27.4s, v15.4s, v3.s[2] \n" // out r1: - // w14 - "fmla v29.4s, v16.4s, v4.s[1] \n" // out r1: - // w17 - "fmla v27.4s, v18.4s, v3.s[3] \n" // out r1: - // w15 - "fmla v29.4s, v19.4s, v4.s[0] \n" // out r1: - // w16 - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" // next r1 - // v15: 4 6 - // 8 10, - // v16: - // trash - // register - - "fmla v27.4s, v17.4s, v4.s[2] \n" // out r1: - // w18 - "fmla v29.4s, v20.4s, v4.s[3] \n" // out r1: - // w19 - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v18: 0 2 - // 4 6, - // v19: 1 3 - // 5 7 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v16: 2 4 - 
// 6 8, - // v11: 3 5 - // 7 9 - - "fmla v27.4s, v23.4s, v5.s[0] \n" // out r1: - // w20 - "fmla v29.4s, v21.4s, v5.s[2] \n" // out r1: - // w22 - "fmla v27.4s, v24.4s, v5.s[1] \n" // out r1: - // w21 - "fmla v29.4s, v22.4s, v5.s[3] \n" // out r1: - // w23 - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" // next r2 - // v20: 4 6 - // 8 10, - // v21: - // trash - // register - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v23: 0 2 - // 4 6, - // v24: 1 3 - // 5 7 - - "fmla v27.4s, v25.4s, v30.s[0] \n" // out r1: - // w24 - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v21: 2 4 - // 6 8, - // v22: 3 5 - // 7 9 - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" // next r3 - // v25: 4 6 - // 8 10, - // v26: - // trash - // register - - "fadd v27.4s, v27.4s, v29.4s \n" - "fmax v27.4s, v27.4s, v31.4s \n" - "cmp %w[mid_cnt], #1 \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - "blt 2f \n" - - // mid loop - "1: \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v27.16b, v26.16b \n" - "mov v28.16b, v31.16b \n" - "mov v29.16b, v31.16b \n" - - // out_r0 r0-r3 - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr0]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v27.4s, v16.4s, v0.s[2] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - - "fmla v28.4s, v17.4s, v3.s[1] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - - "fmla v26.4s, v24.4s, v4.s[0] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - "fadd v28.4s, v26.4s, v28.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]], %[s_16] \n" - "mov v26.16b, v31.16b \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] 
\n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v29.4s, v20.4s, v4.s[3] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v27.4s, v21.4s, v5.s[2] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" - - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - - "fmla v29.4s, v22.4s, v5.s[3] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" - - "fmla v27.4s, v25.4s, v30.s[0] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - "fadd v27.4s, v27.4s, v29.4s \n" - "fmax v26.4s, v26.4s, v31.4s \n" - "fmax v27.4s, v27.4s, v31.4s \n" - - "prfm pldl1keep, [%[din_ptr3]] \n" - "st1 {v26.4s}, [%[dout_ptr0]], #16 \n" - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" - "subs %w[mid_cnt], %w[mid_cnt], #1 \n" - "bne 1b \n" - - "2: \n" - "ld2 {v26.4s, v27.4s}, [%[mask]], %[s_8] \n" - "ld2 {v28.4s, v29.4s}, [%[mask]], %[s_8] \n" - "bif v8.16b, v31.16b, v26.16b \n" - "bif v9.16b, v31.16b, v27.16b \n" - "bif v6.16b, v31.16b, v28.16b \n" - "bif v7.16b, v31.16b, v29.16b \n" - - "bif v13.16b, v31.16b, v26.16b \n" - "bif v14.16b, v31.16b, v27.16b \n" - "bif v11.16b, v31.16b, v28.16b \n" - "bif v12.16b, v31.16b, v29.16b \n" - - "bif v18.16b, v31.16b, v26.16b \n" - "bif v19.16b, v31.16b, v27.16b \n" - "bif v16.16b, v31.16b, v28.16b \n" - "bif v17.16b, v31.16b, v29.16b \n" - - "bif v23.16b, v31.16b, v26.16b \n" - "bif v24.16b, v31.16b, v27.16b \n" - "bif v21.16b, v31.16b, v28.16b \n" - "bif v22.16b, v31.16b, v29.16b \n" - - "ld2 {v28.4s, v29.4s}, [%[mask]] \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v29.16b, v31.16b \n" - - "bif v10.16b, v31.16b, v28.16b \n" - "bif v15.16b, v31.16b, v28.16b \n" - - "mov v27.16b, v26.16b \n" - - "bif v20.16b, v31.16b, v28.16b \n" - "bif v25.16b, v31.16b, v28.16b \n" - "mov v28.16b, v31.16b \n" - - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - "fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "sub %[mask], %[mask], #16 \n" - "ld2 {v6.4s, v7.4s}, [%[mask]], %[s_8] \n" - "ld2 {v8.4s, v9.4s}, [%[mask]], %[s_8] \n" - "ld2 {v10.4s, v11.4s}, [%[mask]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v28.4s, v17.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "bif v13.16b, v31.16b, v6.16b \n" - "bif v14.16b, v31.16b, v7.16b \n" - "bif v11.16b, v31.16b, v8.16b \n" - "bif v12.16b, v31.16b, v9.16b \n" - "bif v15.16b, v31.16b, v10.16b \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], 
%[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]] \n" - - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - "fmla v26.4s, v24.4s, v4.s[0] \n" - - "bif v18.16b, v31.16b, v6.16b \n" - "bif v19.16b, v31.16b, v7.16b \n" - "bif v16.16b, v31.16b, v8.16b \n" - "bif v17.16b, v31.16b, v9.16b \n" - "bif v20.16b, v31.16b, v10.16b \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - "fadd v28.4s, v28.4s, v26.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]] \n" - "mov v26.16b, v31.16b \n" - - "bif v23.16b, v31.16b, v6.16b \n" - "bif v24.16b, v31.16b, v7.16b \n" - "bif v21.16b, v31.16b, v8.16b \n" - "bif v22.16b, v31.16b, v9.16b \n" - "bif v25.16b, v31.16b, v10.16b \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v26.4s, v15.4s, v30.s[0] \n" - - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - - "fmax v26.4s, v26.4s, v31.4s \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - "fmla v29.4s, v20.4s, v4.s[3] \n" - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "st1 {v26.4s}, [%[out_buf0]] \n" - "fmla v27.4s, v21.4s, v5.s[2] \n" - "fmla v29.4s, v22.4s, v5.s[3] \n" - "fmla v27.4s, v25.4s, v30.s[0] \n" - "fadd v27.4s, v27.4s, v29.4s \n" - - "fmax v27.4s, v27.4s, v31.4s \n" - "st1 {v27.4s}, [%[out_buf1]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [dout_ptr1] "+r"(dout_ptr1), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [din_ptr6] "+r"(din_ptr6), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - dout_ptr1[i] = out_buf1[i]; - } - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din6 + w_in; - din4 = din3 + w_in; - din5 = din4 + w_in; - din6 = din5 + w_in; - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - } - } -} - -//! 
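// The small-width path below (win < 9) emits one output row per pass and
// leans entirely on the 12-entry lane mask: with 2 columns of implicit left
// padding, a row occupies w_in + 2 of the 12 lanes, so the trailing
// 12 - (w_in + 2) lanes are zeroed and every load is blended before use. A
// minimal sketch of the mask setup, assuming 32-bit lanes:
//   inline void build_small_row_mask(int w_in, uint32_t mask[12]) {
//     for (int i = 0; i < 12; ++i) mask[i] = 0xffffffffu;  // keep by default
//     const int mask_cnt = 12 - w_in - 2;               // lanes past the row
//     for (int i = 0; i < mask_cnt; ++i) mask[11 - i] = 0u;  // zero the tail
//   }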
small depthwise, win < 9; -void conv_depthwise_5x5s2p2_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - //! in r0/r4, r1, r2, r3: x 0 2 4 -- v8 v13 v18 v23 v28 - //! in r0/r4, r1, r2, r3: x 1 3 5 -- v9 v14 v19 v24 v29 - //! in r0/r4, r1, r2, r3: 0 2 4 6 -- v6 v11 v16 v21 v26 - //! in r0/r4, r1, r2, r3: 1 3 5 7 -- v7 v12 v17 v22 v27 - //! in r0/r4, r1, r2, r3: 2 4 6 8 -- v10 v15 v20 v25 v30 - //! out r0 -- v4 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - - //! load mask - "ld2 {v0.4s, v1.4s}, [%[mask]], %[s_8] \n" - "ld2 {v2.4s, v3.4s}, [%[mask]], %[s_8] \n" - "ld2 {v4.4s, v5.4s}, [%[mask]] \n" - - //! 
load and extract input - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" - "ld2 {v26.4s, v27.4s}, [%[din_ptr4]], #32 \n" - - "ext v8.16b, v31.16b, v6.16b, #12 \n" - "ext v9.16b, v31.16b, v7.16b, #12 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" - "ext v14.16b, v31.16b, v12.16b, #12 \n" - - "ext v18.16b, v31.16b, v16.16b, #12 \n" - "ext v19.16b, v31.16b, v17.16b, #12 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" - "ext v24.16b, v31.16b, v22.16b, #12 \n" - "ext v28.16b, v31.16b, v26.16b, #12 \n" - "ext v29.16b, v31.16b, v27.16b, #12 \n" - - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ext v30.16b, v26.16b, v31.16b, #4 \n" - - "bif v8.16b, v31.16b, v0.16b \n" - "bif v9.16b, v31.16b, v1.16b \n" - "bif v6.16b, v31.16b, v2.16b \n" - "bif v7.16b, v31.16b, v3.16b \n" - - "bif v13.16b, v31.16b, v0.16b \n" - "bif v14.16b, v31.16b, v1.16b \n" - "bif v11.16b, v31.16b, v2.16b \n" - "bif v12.16b, v31.16b, v3.16b \n" - - "bif v18.16b, v31.16b, v0.16b \n" - "bif v19.16b, v31.16b, v1.16b \n" - "bif v16.16b, v31.16b, v2.16b \n" - "bif v17.16b, v31.16b, v3.16b \n" - - "ld1 {v10.s}[3], [%[din_ptr0]] \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" - "ld1 {v30.s}[3], [%[din_ptr4]] \n" - - "bif v23.16b, v31.16b, v0.16b \n" - "bif v24.16b, v31.16b, v1.16b \n" - "bif v21.16b, v31.16b, v2.16b \n" - "bif v22.16b, v31.16b, v3.16b \n" - - "bif v28.16b, v31.16b, v0.16b \n" - "bif v29.16b, v31.16b, v1.16b \n" - "bif v26.16b, v31.16b, v2.16b \n" - "bif v27.16b, v31.16b, v3.16b \n" - - "bif v10.16b, v31.16b, v4.16b \n" - "bif v15.16b, v31.16b, v4.16b \n" - "bif v20.16b, v31.16b, v4.16b \n" - "bif v25.16b, v31.16b, v4.16b \n" - "bif v30.16b, v31.16b, v4.16b \n" - - "ld1 {v4.4s}, [%[vbias]] \n" - "mov v5.16b, v31.16b \n" - - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load weights 0-7 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load weights 8-15 - - //! 
compute - "fmla v4.4s, v8.4s, v0.s[0] \n" // out r0: w0 - "fmla v5.4s, v9.4s, v0.s[1] \n" // out r0: w1 - "fmla v4.4s, v6.4s, v0.s[2] \n" // out r0: w2 - "fmla v5.4s, v7.4s, v0.s[3] \n" // out r0: w3 - - "fmla v4.4s, v10.4s, v1.s[0] \n" // out r0: w4 - "fmla v5.4s, v13.4s, v1.s[1] \n" // out r0: w5 - "fmla v4.4s, v14.4s, v1.s[2] \n" // out r0: w6 - "fmla v5.4s, v11.4s, v1.s[3] \n" // out r0: w7 - - "ld1 {v6.4s, v7.4s}, [%[weights]], #32 \n" // load weights 16-23 - "ld1 {v8.s}[0], [%[weights]] \n" // load weights 24 - - "fmla v4.4s, v12.4s, v2.s[0] \n" // out r0: w8 - "fmla v5.4s, v15.4s, v2.s[1] \n" // out r0: w9 - "fmla v4.4s, v18.4s, v2.s[2] \n" // out r0: w10 - "fmla v5.4s, v19.4s, v2.s[3] \n" // out r0: w11 - - "fmla v4.4s, v16.4s, v3.s[0] \n" // out r0: w12 - "fmla v5.4s, v17.4s, v3.s[1] \n" // out r0: w13 - "fmla v4.4s, v20.4s, v3.s[2] \n" // out r0: w14 - "fmla v5.4s, v23.4s, v3.s[3] \n" // out r0: w15 - - "fmla v4.4s, v24.4s, v6.s[0] \n" // out r0: w16 - "fmla v5.4s, v21.4s, v6.s[1] \n" // out r0: w17 - "fmla v4.4s, v22.4s, v6.s[2] \n" // out r0: w18 - "fmla v5.4s, v25.4s, v6.s[3] \n" // out r0: w19 - - "fmla v4.4s, v28.4s, v7.s[0] \n" // out r0: w20 - "fmla v5.4s, v29.4s, v7.s[1] \n" // out r0: w21 - "fmla v4.4s, v26.4s, v7.s[2] \n" // out r0: w22 - "fmla v5.4s, v27.4s, v7.s[3] \n" // out r0: w23 - "fmla v4.4s, v30.4s, v8.s[0] \n" // out r0: w24 - - "fadd v4.4s, v4.4s, v5.4s \n" // add out to v4 - "st1 {v4.4s}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! 
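// The _relu_s variant below is identical to the plain small kernel above
// except for a single extra `fmax` against the zero register before the
// store, which is how ReLU fuses for free into the epilogue. A minimal
// intrinsics sketch, assuming <arm_neon.h>:
//   #include <arm_neon.h>
//   inline void store_relu(float* dst, float32x4_t acc) {
//     acc = vmaxq_f32(acc, vdupq_n_f32(0.f));  // clamp negatives: ReLU
//     vst1q_f32(dst, acc);                     // write the four outputs
//   }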
small depthwise, win < 9; -void conv_depthwise_5x5s2p2_relu_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - //! in r0/r4, r1, r2, r3: x 0 2 4 -- v8 v13 v18 v23 v28 - //! in r0/r4, r1, r2, r3: x 1 3 5 -- v9 v14 v19 v24 v29 - //! in r0/r4, r1, r2, r3: 0 2 4 6 -- v6 v11 v16 v21 v26 - //! in r0/r4, r1, r2, r3: 1 3 5 7 -- v7 v12 v17 v22 v27 - //! in r0/r4, r1, r2, r3: 2 4 6 8 -- v10 v15 v20 v25 v30 - //! out r0 -- v4 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - - //! load mask - "ld2 {v0.4s, v1.4s}, [%[mask]], %[s_8] \n" - "ld2 {v2.4s, v3.4s}, [%[mask]], %[s_8] \n" - "ld2 {v4.4s, v5.4s}, [%[mask]] \n" - - //! 
load and extract input - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" - "ld2 {v26.4s, v27.4s}, [%[din_ptr4]], #32 \n" - - "ext v8.16b, v31.16b, v6.16b, #12 \n" - "ext v9.16b, v31.16b, v7.16b, #12 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" - "ext v14.16b, v31.16b, v12.16b, #12 \n" - - "ext v18.16b, v31.16b, v16.16b, #12 \n" - "ext v19.16b, v31.16b, v17.16b, #12 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" - "ext v24.16b, v31.16b, v22.16b, #12 \n" - "ext v28.16b, v31.16b, v26.16b, #12 \n" - "ext v29.16b, v31.16b, v27.16b, #12 \n" - - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ext v30.16b, v26.16b, v31.16b, #4 \n" - - "bif v8.16b, v31.16b, v0.16b \n" - "bif v9.16b, v31.16b, v1.16b \n" - "bif v6.16b, v31.16b, v2.16b \n" - "bif v7.16b, v31.16b, v3.16b \n" - - "bif v13.16b, v31.16b, v0.16b \n" - "bif v14.16b, v31.16b, v1.16b \n" - "bif v11.16b, v31.16b, v2.16b \n" - "bif v12.16b, v31.16b, v3.16b \n" - - "bif v18.16b, v31.16b, v0.16b \n" - "bif v19.16b, v31.16b, v1.16b \n" - "bif v16.16b, v31.16b, v2.16b \n" - "bif v17.16b, v31.16b, v3.16b \n" - - "ld1 {v10.s}[3], [%[din_ptr0]] \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" - "ld1 {v30.s}[3], [%[din_ptr4]] \n" - - "bif v23.16b, v31.16b, v0.16b \n" - "bif v24.16b, v31.16b, v1.16b \n" - "bif v21.16b, v31.16b, v2.16b \n" - "bif v22.16b, v31.16b, v3.16b \n" - - "bif v28.16b, v31.16b, v0.16b \n" - "bif v29.16b, v31.16b, v1.16b \n" - "bif v26.16b, v31.16b, v2.16b \n" - "bif v27.16b, v31.16b, v3.16b \n" - - "bif v10.16b, v31.16b, v4.16b \n" - "bif v15.16b, v31.16b, v4.16b \n" - "bif v20.16b, v31.16b, v4.16b \n" - "bif v25.16b, v31.16b, v4.16b \n" - "bif v30.16b, v31.16b, v4.16b \n" - - "ld1 {v4.4s}, [%[vbias]] \n" - "mov v5.16b, v31.16b \n" - - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load weights 0-7 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load weights 8-15 - - //! 
compute - "fmla v4.4s, v8.4s, v0.s[0] \n" // out r0: w0 - "fmla v5.4s, v9.4s, v0.s[1] \n" // out r0: w1 - "fmla v4.4s, v6.4s, v0.s[2] \n" // out r0: w2 - "fmla v5.4s, v7.4s, v0.s[3] \n" // out r0: w3 - - "fmla v4.4s, v10.4s, v1.s[0] \n" // out r0: w4 - "fmla v5.4s, v13.4s, v1.s[1] \n" // out r0: w5 - "fmla v4.4s, v14.4s, v1.s[2] \n" // out r0: w6 - "fmla v5.4s, v11.4s, v1.s[3] \n" // out r0: w7 - - "ld1 {v6.4s, v7.4s}, [%[weights]], #32 \n" // load weights 16-23 - "ld1 {v8.s}[0], [%[weights]] \n" // load weights 24 - - "fmla v4.4s, v12.4s, v2.s[0] \n" // out r0: w8 - "fmla v5.4s, v15.4s, v2.s[1] \n" // out r0: w9 - "fmla v4.4s, v18.4s, v2.s[2] \n" // out r0: w10 - "fmla v5.4s, v19.4s, v2.s[3] \n" // out r0: w11 - - "fmla v4.4s, v16.4s, v3.s[0] \n" // out r0: w12 - "fmla v5.4s, v17.4s, v3.s[1] \n" // out r0: w13 - "fmla v4.4s, v20.4s, v3.s[2] \n" // out r0: w14 - "fmla v5.4s, v23.4s, v3.s[3] \n" // out r0: w15 - - "fmla v4.4s, v24.4s, v6.s[0] \n" // out r0: w16 - "fmla v5.4s, v21.4s, v6.s[1] \n" // out r0: w17 - "fmla v4.4s, v22.4s, v6.s[2] \n" // out r0: w18 - "fmla v5.4s, v25.4s, v6.s[3] \n" // out r0: w19 - - "fmla v4.4s, v28.4s, v7.s[0] \n" // out r0: w20 - "fmla v5.4s, v29.4s, v7.s[1] \n" // out r0: w21 - "fmla v4.4s, v26.4s, v7.s[2] \n" // out r0: w22 - "fmla v5.4s, v27.4s, v7.s[3] \n" // out r0: w23 - "fmla v4.4s, v30.4s, v8.s[0] \n" // out r0: w24 - - "fadd v4.4s, v4.4s, v5.4s \n" // add out to v4 - "fmax v4.4s, v4.4s, v31.4s \n" - "st1 {v4.4s}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - +#ifdef __aarch64__ + act_switch_5x5s2(inr0, + inr1, + inr2, + inr3, + inr4, + outc0, + outc1, + outc2, + outc3, + w0, + w1, + w2, + w3, + w4, + vbias, + weight_c, + bias_local, + act_param); #else - -//! 
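// The armv7 path below works with 16 q registers (versus 32 v registers on
// aarch64), so the kernel receives w0..w5 through "w" operand constraints
// and briefly spills one of them (the vmov.32 q12 / vld1.32 juggling) to
// bring in the 25th weight. Each tap is still a vector multiply-accumulate
// against a broadcast scalar weight; a minimal intrinsics sketch, assuming
// <arm_neon.h>:
//   #include <arm_neon.h>
//   inline float32x4_t fma_tap(float32x4_t acc, float32x4_t in, float w) {
//     return vmlaq_f32(acc, in, vdupq_n_f32(w));  // acc += in * w
//   }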
larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - // printf("invoke 5x5s2p2 armv7\n"); - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float* dout0 = dout_ch; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "pld [%[mask]] \n" - - // left - "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" - "vld1.32 {d26-d29}, [%[vbias]] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! \n" - "sub %[din_ptr0], #8 \n" - - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r1 - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! 
\n" - "sub %[din_ptr1], #8 \n" - - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r2 - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[din_ptr2], #8 \n" - - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r3 - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" - "sub %[din_ptr3], #8 \n" - - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "sub %[din_ptr4], #8 \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmov.32 q12, %q[w0] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - "vmla.f32 q13, q10, %e[w0][0] \n" - "vadd.f32 q13, q13, q14 \n" - "vmov.32 %q[w0], q12 \n" - "cmp %[mid_cnt], #1 \n" - "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! \n" - "pld [%[din_ptr0]] \n" - "blt 2f \n" - - // mid - "1: \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr1]], %[s_16] \n" - - // r1 - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "pld [%[din_ptr1]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr2]], %[s_16] \n" - - // r2 - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "pld [%[din_ptr2]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr3]], %[s_16] \n" - - // r3 - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "pld [%[din_ptr3]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr4]], %[s_16] \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - "pld [%[din_ptr4]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w0][0] \n" - - "vld2.32 
{d20-d23}, [%[din_ptr0]], %[s_16] \n" - - "vmov.32 %q[w0], q12 \n" - "vadd.f32 q13, q13, q14 \n" - "subs %[mid_cnt], #1 \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! \n" - "bne 1b \n" - - "2: \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d20-d23}, [%[din_ptr1]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d20-d23}, [%[din_ptr2]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d20-d23}, [%[din_ptr3]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d20-d23}, [%[din_ptr4]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [w0] "w"(w0), - 
[w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2_relu(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - // printf("invoke 5x5s2p2 armv7\n"); - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float* dout0 = dout_ch; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "pld [%[mask]] \n" - - // left - "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" - "vld1.32 {d26-d29}, [%[vbias]] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! 
\n" - "sub %[din_ptr0], #8 \n" - - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r1 - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" - "sub %[din_ptr1], #8 \n" - - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r2 - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[din_ptr2], #8 \n" - - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r3 - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" - "sub %[din_ptr3], #8 \n" - - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "sub %[din_ptr4], #8 \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmov.32 q12, %q[w0] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - "vmla.f32 q13, q10, %e[w0][0] \n" - "vadd.f32 q13, q13, q14 \n" - "vmov.f32 %q[w0], q12 \n" - "vmax.f32 q13, q13, q15 \n" - "cmp %[mid_cnt], #1 \n" - "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! 
\n" - "pld [%[din_ptr0]] \n" - "blt 2f \n" - - // mid - "1: \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr1]], %[s_16] \n" - - // r1 - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "pld [%[din_ptr1]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr2]], %[s_16] \n" - - // r2 - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "pld [%[din_ptr2]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr3]], %[s_16] \n" - - // r3 - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "pld [%[din_ptr3]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr4]], %[s_16] \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - "pld [%[din_ptr4]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w0][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" - - "vmov.32 %q[w0], q12 \n" - "vadd.f32 q13, q13, q14 \n" - "vmax.f32 q13, q13, q15 \n" - "subs %[mid_cnt], #1 \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! 
\n" - "bne 1b \n" - - "2: \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d20-d23}, [%[din_ptr1]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d20-d23}, [%[din_ptr2]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d20-d23}, [%[din_ptr3]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d20-d23}, [%[din_ptr4]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vmax.f32 q13, q13, q15 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - 
[s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! small depthwise, win < 9; -void conv_depthwise_5x5s2p2_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" - - // r0 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! 
\n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %e[w1][1] \n" - "vmla.f32 q13, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %f[w1][1] \n" - "vmla.f32 q13, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %f[w3][1] \n" - "vmla.f32 q13, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %e[w4][1] \n" - "vmla.f32 q13, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [s_8] "r"(s_8), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! 
small depthwise, win < 9; -void conv_depthwise_5x5s2p2_relu_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9\n"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; + act_switch_5x5s2(inr0, + inr1, + inr2, + inr3, + inr4, + outc0, + outc1, + outc2, + outc3, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + weight_c, + bias_local, + act_param); +#endif + if (flag_mask) { + for (int i = 0; i < remain; ++i) { + c0[i] = pre_out[i]; + c1[i] = pre_out[i + 4]; + c2[i] = pre_out[i + 8]; + c3[i] = pre_out[i + 12]; + } } + inr0 += 32; + inr1 += 32; + inr2 += 32; + inr3 += 32; + inr4 += 32; + outc0 += 4; + outc1 += 4; + outc2 += 4; + outc3 += 4; } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - "vld2.32 {d16-d19}, [%[din_ptr0]]! 
\n" - - // r0 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %e[w1][1] \n" - "vmla.f32 q13, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %f[w1][1] \n" - "vmla.f32 q13, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %f[w3][1] \n" - "vmla.f32 q13, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %e[w4][1] \n" - "vmla.f32 q13, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! 
\n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vmax.f32 q13, q13, q15 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [s_8] "r"(s_8), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; } } } } -#endif // __aarch64__ } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc new file mode 100644 index 0000000000000000000000000000000000000000..c778896550de73f888979c8337731a0b9967b5dd --- /dev/null +++ b/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc @@ -0,0 +1,795 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_depthwise.h" +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) + +template +void conv_depthwise_5x5s2_int8(Dtype* dout, + const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, + bool flag_bias, + bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + ARMContext* ctx) { + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + + const int hout_c_block = 8; + const int hout_r_kernel = 1; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round * 2 + 3; + + //! get h block + //! 
llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(int8_t) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t) + //! win_round = wout_round * 2 + 3 + //! hin_r_block = hout_r_block * 2 + 3 + int hout_r_block = (llc_size - 3 * win_round * hout_c_block * threads) / + (2 * win_round * hout_c_block * threads + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block * 2 + 3; + + auto tmp_work_space = ctx->workspace_data(); + int8_t ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(int8_t) * win_round); + Dtype ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + int8_t* tmp_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + + int out_row_stride = hout_c_block * wout_round; + for (int n = 0; n < num; ++n) { + const int8_t* din_batch = din + n * chin * size_in_channel; + int8_t* dout_batch = reinterpret_cast(dout) + + n * chout * size_out_channel * sizeof(Dtype); + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel * 2 + 3; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + int8_t* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size * 4); + int32_t* pre_out = reinterpret_cast(pre_din + pre_in_size); +#else + int32_t* pre_out = reinterpret_cast(tmp_din + pre_in_size); + auto pre_din = tmp_din; +#endif + prepack_input_nxwc8_int8_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin); + + const int8_t* block_inr0 = pre_din; + const int8_t* block_inr1 = block_inr0 + in_len; + const int8_t* block_inr2 = block_inr1 + in_len; + const int8_t* block_inr3 = block_inr2 + in_len; + const int8_t* block_inr4 = block_inr3 + in_len; + + const int8_t* weight_c = weights + c * w_stride; + float bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + bias_local[4] = bias[c + 4]; + bias_local[5] = bias[c + 5]; + bias_local[6] = bias[c + 6]; + bias_local[7] = bias[c + 7]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + const int8_t* inr3 = block_inr3; + const int8_t* inr4 = block_inr4; + + int32_t* ptr_out0 = pre_out + hk * out_row_stride; +// clang-format off +#ifdef __aarch64__ + auto wptr = weight_c; + asm volatile( + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]], #32\n" /* load r0 4-7 */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 0-3 */ + "1:\n" + /* in r0 */ + "smull v20.8h, v0.8b, v12.8b\n" /* w0, int16, out0 */ + "smull v21.8h, v2.8b, v12.8b\n" /* w0, int16, out1 
*/ + "smull v22.8h, v4.8b, v12.8b\n" /* w0, int16, out2 */ + "smull v23.8h, v6.8b, v12.8b\n" /* w0, int16, out3 */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[r0]]\n" /* load r0 8-11 */ + "smlal v20.8h, v1.8b, v13.8b\n" /* w1, int16, out0 */ + "smlal v21.8h, v3.8b, v13.8b\n" /* w1, int16, out1 */ + "smlal v22.8h, v5.8b, v13.8b\n" /* w1, int16, out2 */ + "smlal v23.8h, v7.8b, v13.8b\n" /* w1, int16, out3 */ + "sxtl v24.4s, v20.4h\n" /* mov to out0 low */ + "sxtl2 v25.4s, v20.8h\n" /* mov to out0 hig */ + "sxtl v26.4s, v21.4h\n" /* mov to out1 low */ + "sxtl2 v27.4s, v21.8h\n" /* mov to out1 hig */ + "sxtl v28.4s, v22.4h\n" /* mov to out2 low */ + "sxtl2 v29.4s, v22.8h\n" /* mov to out2 hig */ + "sxtl v30.4s, v23.4h\n" /* mov to out3 low */ + "sxtl2 v31.4s, v23.8h\n" /* mov to out3 hig */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 4-7 */ + + "smull v20.8h, v2.8b, v14.8b\n" /* w2, int16, out0 */ + "smull v21.8h, v4.8b, v14.8b\n" /* w2, int16, out1 */ + "smull v22.8h, v6.8b, v14.8b\n" /* w2, int16, out2 */ + "smull v23.8h, v8.8b, v14.8b\n" /* w2, int16, out3 */ + "smlal v20.8h, v3.8b, v15.8b\n" /* w3, int16, out0 */ + "smlal v21.8h, v5.8b, v15.8b\n" /* w3, int16, out1 */ + "smlal v22.8h, v7.8b, v15.8b\n" /* w3, int16, out2 */ + "smlal v23.8h, v9.8b, v15.8b\n" /* w3, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r1]], #32\n" /* load r1 0-3 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v4.8b, v16.8b\n" /* w4, int16, out0 */ + "smull v21.8h, v6.8b, v16.8b\n" /* w4, int16, out1 */ + "smull v22.8h, v8.8b, v16.8b\n" /* w4, int16, out2 */ + "smull v23.8h, v10.8b, v16.8b\n" /* w4, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r1]], #32\n" /* load r1 4-7 */ + /* in r1 */ + "smlal v20.8h, v0.8b, v17.8b\n" /* w5, int16, out0 */ + "smlal v21.8h, v2.8b, v17.8b\n" /* w5, int16, out1 */ + "smlal v22.8h, v4.8b, v17.8b\n" /* w5, int16, out2 */ + "smlal v23.8h, v6.8b, v17.8b\n" /* w5, int16, out3 */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[r1]]\n" /* load r1 8-11 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v1.8b, v18.8b\n" /* w6, int16, out0 */ + "smull v21.8h, v3.8b, v18.8b\n" /* w6, int16, out1 */ + "smull v22.8h, v5.8b, v18.8b\n" /* w6, int16, out2 */ + "smull v23.8h, v7.8b, v18.8b\n" /* w6, int16, out3 */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 8-11 */ + "smlal v20.8h, v2.8b, v19.8b\n" /* w7, int16, out0 */ + "smlal v21.8h, v4.8b, v19.8b\n" /* w7, int16, out1 */ + "smlal v22.8h, v6.8b, v19.8b\n" /* w7, int16, out2 */ + "smlal v23.8h, v8.8b, v19.8b\n" /* w7, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig 
*/ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 12-15 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v3.8b, v12.8b\n" /* w8, int16, out0 */ + "smull v21.8h, v5.8b, v12.8b\n" /* w8, int16, out1 */ + "smull v22.8h, v7.8b, v12.8b\n" /* w8, int16, out2 */ + "smull v23.8h, v9.8b, v12.8b\n" /* w8, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r2]], #32\n" /* load r2 0-3 */ + "smlal v20.8h, v4.8b, v13.8b\n" /* w9, int16, out0 */ + "smlal v21.8h, v6.8b, v13.8b\n" /* w9, int16, out1 */ + "smlal v22.8h, v8.8b, v13.8b\n" /* w9, int16, out2 */ + "smlal v23.8h, v10.8b, v13.8b\n" /* w9, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r2]], #32\n" /* load r2 4-7 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + /* in r2 */ + "smull v20.8h, v0.8b, v14.8b\n" /* w10, int16, out0 */ + "smull v21.8h, v2.8b, v14.8b\n" /* w10, int16, out1 */ + "smull v22.8h, v4.8b, v14.8b\n" /* w10, int16, out2 */ + "smull v23.8h, v6.8b, v14.8b\n" /* w10, int16, out3 */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[r2]]\n" /* load r2 8-11 */ + "smlal v20.8h, v1.8b, v15.8b\n" /* w11, int16, out0 */ + "smlal v21.8h, v3.8b, v15.8b\n" /* w11, int16, out1 */ + "smlal v22.8h, v5.8b, v15.8b\n" /* w11, int16, out2 */ + "smlal v23.8h, v7.8b, v15.8b\n" /* w11, int16, out3 */ + + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 16-19 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v2.8b, v16.8b\n" /* w12, int16, out0 */ + "smull v21.8h, v4.8b, v16.8b\n" /* w12, int16, out1 */ + "smull v22.8h, v6.8b, v16.8b\n" /* w12, int16, out2 */ + "smull v23.8h, v8.8b, v16.8b\n" /* w12, int16, out3 */ + "smlal v20.8h, v3.8b, v17.8b\n" /* w13, int16, out0 */ + "smlal v21.8h, v5.8b, v17.8b\n" /* w13, int16, out1 */ + "smlal v22.8h, v7.8b, v17.8b\n" /* w13, int16, out2 */ + "smlal v23.8h, v9.8b, v17.8b\n" /* w13, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r3]], #32\n" /* load r3 0-3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + 
"smull v20.8h, v4.8b, v18.8b\n" /* w14, int16, out0 */ + "smull v21.8h, v6.8b, v18.8b\n" /* w14, int16, out1 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r3]], #32\n" /* load r3 4-7 */ + "smull v22.8h, v8.8b, v18.8b\n" /* w14, int16, out2 */ + "smull v23.8h, v10.8b, v18.8b\n" /* w14, int16, out3 */ + /* in r3 */ + "smlal v20.8h, v0.8b, v19.8b\n" /* w15, int16, out0 */ + "smlal v21.8h, v2.8b, v19.8b\n" /* w15, int16, out1 */ + "smlal v22.8h, v4.8b, v19.8b\n" /* w15, int16, out2 */ + "smlal v23.8h, v6.8b, v19.8b\n" /* w15, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[r3]]\n" /* load r3 8-11 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v1.8b, v12.8b\n" /* w16, int16, out0 */ + "smull v21.8h, v3.8b, v12.8b\n" /* w16, int16, out1 */ + "smull v22.8h, v5.8b, v12.8b\n" /* w16, int16, out2 */ + "smull v23.8h, v7.8b, v12.8b\n" /* w16, int16, out3 */ + "ld1 {v16.8b, v17.8b, v18.8b, v19.8b}, [%[wc]], #32\n" /* load wc 20-23 */ + "smlal v20.8h, v2.8b, v13.8b\n" /* w17, int16, out0 */ + "smlal v21.8h, v4.8b, v13.8b\n" /* w17, int16, out1 */ + "smlal v22.8h, v6.8b, v13.8b\n" /* w17, int16, out2 */ + "smlal v23.8h, v8.8b, v13.8b\n" /* w17, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v3.8b, v14.8b\n" /* w18, int16, out0 */ + "smull v21.8h, v5.8b, v14.8b\n" /* w18, int16, out1 */ + "smull v22.8h, v7.8b, v14.8b\n" /* w18, int16, out2 */ + "smull v23.8h, v9.8b, v14.8b\n" /* w18, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r4]], #32\n" /* load r4 0-3 */ + "smlal v20.8h, v4.8b, v15.8b\n" /* w19, int16, out0 */ + "smlal v21.8h, v6.8b, v15.8b\n" /* w19, int16, out1 */ + "smlal v22.8h, v8.8b, v15.8b\n" /* w19, int16, out2 */ + "smlal v23.8h, v10.8b, v15.8b\n" /* w19, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r4]], #32\n" /* load r4 4-7 */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + /* in r4 */ + "smull v20.8h, v0.8b, v16.8b\n" /* w20, int16, out0 */ + "smull v21.8h, v2.8b, v16.8b\n" /* w20, int16, out1 */ + "smull v22.8h, v4.8b, v16.8b\n" /* w20, int16, out2 */ + "smull v23.8h, v6.8b, v16.8b\n" /* w20, int16, out3 */ + "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%[r4]]\n" /* load r4 8-11 */ + "smlal v20.8h, v1.8b, v17.8b\n" /* w21, int16, out0 */ + "smlal v21.8h, v3.8b, v17.8b\n" /* w21, int16, out1 */ + 
"smlal v22.8h, v5.8b, v17.8b\n" /* w21, int16, out2 */ + "smlal v23.8h, v7.8b, v17.8b\n" /* w21, int16, out3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "ld1 {v16.8b}, [%[wc]], #8\n" /* load wc 24 */ + "smull v20.8h, v2.8b, v18.8b\n" /* w22, int16, out0 */ + "smull v21.8h, v4.8b, v18.8b\n" /* w22, int16, out1 */ + "smull v22.8h, v6.8b, v18.8b\n" /* w22, int16, out2 */ + "smull v23.8h, v8.8b, v18.8b\n" /* w22, int16, out3 */ + "sub %[wc], %[wc], #200 \n" + "smlal v20.8h, v3.8b, v19.8b\n" /* w23, int16, out0 */ + "smlal v21.8h, v5.8b, v19.8b\n" /* w23, int16, out1 */ + "smlal v22.8h, v7.8b, v19.8b\n" /* w23, int16, out2 */ + "smlal v23.8h, v9.8b, v19.8b\n" /* w23, int16, out3 */ + "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[r0]], #32\n" /* load r0 0-3 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "ld1 {v12.8b, v13.8b, v14.8b, v15.8b}, [%[wc]], #32\n" /* load wc 0-3 */ + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + + "smull v20.8h, v4.8b, v16.8b\n" /* w24, int16, out0 */ + "smull v21.8h, v6.8b, v16.8b\n" /* w24, int16, out1 */ + "smull v22.8h, v8.8b, v16.8b\n" /* w24, int16, out2 */ + "smull v23.8h, v10.8b, v16.8b\n" /* w24, int16, out3 */ + "ld1 {v4.8b, v5.8b, v6.8b, v7.8b}, [%[r0]], #32\n" /* load r0 4-7 */ + "saddw v24.4s, v24.4s, v20.4h\n" /* add to out0 low */ + "saddw2 v25.4s, v25.4s, v20.8h\n" /* add to out0 hig */ + "saddw v26.4s, v26.4s, v21.4h\n" /* add to out1 low */ + "saddw2 v27.4s, v27.4s, v21.8h\n" /* add to out1 hig */ + "stp q24, q25, [%[ptr_out0]], #32\n" + "saddw v28.4s, v28.4s, v22.4h\n" /* add to out2 low */ + "saddw2 v29.4s, v29.4s, v22.8h\n" /* add to out2 hig */ + "stp q26, q27, [%[ptr_out0]], #32\n" + "saddw v30.4s, v30.4s, v23.4h\n" /* add to out3 low */ + "saddw2 v31.4s, v31.4s, v23.8h\n" /* add to out3 hig */ + "subs %w[cnt], %w[cnt], #1\n" + "stp q28, q29, [%[ptr_out0]], #32\n" + "stp q30, q31, [%[ptr_out0]], #32\n" + "bne 1b\n" + : [cnt] "+r"(cnt), + [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc] "+r"(wptr), + [ptr_out0] "+r"(ptr_out0) + : + : "cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25", + "v26","v27","v28","v29","v30","v31" + ); +#else + auto wptr = weight_c; + asm volatile( + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */ + "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */ + "1:\n" + /* inr0 */ + "vmull.s8 q4, d0, d6\n" /* int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* int16, out2 */ + "vmlal.s8 q4, d1, d7\n" /* int16, out0 */ + "vld1.32 {d0-d1}, [%[r0]]!\n" /* load r0, 6-7 */ + "vmlal.s8 q5, d3, d7\n" /* int16, out1 
*/ + "vmlal.s8 q6, d5, d7\n" /* int16, out2 */ + "vmovl.s16 q8, d8\n" /* mov to out0 low */ + "vmull.s8 q7, d0, d6\n" /* int16, out3 */ + "vmovl.s16 q9, d9\n" /* mov to out0 hig */ + "vmovl.s16 q10, d10\n" /* mov to out1 low */ + "vmovl.s16 q11, d11\n" /* mov to out1 hig */ + "vmlal.s8 q7, d1, d7\n" /* int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w2-w3 */ + "vmovl.s16 q12, d12\n" /* mov to out2 low */ + "vmovl.s16 q13, d13\n" /* mov to out2 hig */ + "vmovl.s16 q14, d14\n" /* mov to out3 low */ + "vmovl.s16 q15, d15\n" /* mov to out3 hig */ + + "vmull.s8 q4, d2, d6\n" /* w2, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w2, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w2, int16, out2 */ + "vmlal.s8 q4, d3, d7\n" /* w3, int16, out0 */ + "vld1.32 {d2-d3}, [%[r0]]!\n" /* load r0, 8-9 */ + "vmlal.s8 q5, d5, d7\n" /* w3, int16, out1 */ + "vmlal.s8 q6, d1, d7\n" /* w3, int16, out2 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vmull.s8 q7, d2, d6\n" /* w2, int16, out3 */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d3, d7\n" /* w3, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w4-w5 */ + "vld1.32 {d5}, [%[r0]]\n" /* load r0, 10 */ + "sub %[r0], %[r0], #16\n" /* r0 = r0 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d4, d6\n" /* w4, int16, out0 */ + "vmull.s8 q5, d0, d6\n" /* w4, int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* w4, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w4, int16, out3 */ + "vld1.32 {d0-d3}, [%[r1]]!\n" /* load r1, 0-3 */ + "vld1.32 {d4-d5}, [%[r1]]!\n" /* load r1, 4-5 */ + /* inr1 */ + "vmlal.s8 q4, d0, d7\n" /* w5, int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* w5, int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* w5, int16, out2 */ + "vld1.32 {d0}, [%[r1]]!\n" /* load r1, 6 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d0, d7\n" /* w5, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w6-w7 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d1, d6\n" /* w6, int16, out0 */ + "vld1.32 {d1}, [%[r1]]!\n" /* load r1, 7 */ + "vmull.s8 q5, d3, d6\n" /* w6, int16, out1 */ + "vmull.s8 q6, d5, d6\n" /* w6, int16, out2 */ + "vmlal.s8 q4, d2, d7\n" /* w7, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w7, int16, out1 */ + "vmlal.s8 q6, d0, d7\n" /* w7, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w6, int16, out3 */ + "vld1.32 {d2}, [%[r1]]!\n" /* load r1, 8 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d2, d7\n" /* w7, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w8-w9 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d3, d6\n" /* w8, int16, out0 */ + 
"vld1.32 {d3}, [%[r1]]!\n" /* load r1, 9 */ + "vmull.s8 q5, d5, d6\n" /* w8, int16, out1 */ + "vmull.s8 q6, d1, d6\n" /* w8, int16, out2 */ + "vld1.32 {d5}, [%[r1]]\n" /* load r1, 10 */ + "vmlal.s8 q4, d4, d7\n" /* w9, int16, out0 */ + "vmlal.s8 q5, d0, d7\n" /* w9, int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* w9, int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* w8, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d5, d7\n" /* w9, int16, out3 */ + "sub %[r1], %[r1], #16\n" /* r1 = r1 - 16 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w10-w11 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d0-d3}, [%[r2]]!\n" /* load r2, 0-3 */ + "vld1.32 {d4-d5}, [%[r2]]!\n" /* load r2, 4-5 */ + + /* inr2 */ + "vmull.s8 q4, d0, d6\n" /* w10, int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* w10, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w10, int16, out2 */ + "vmlal.s8 q4, d1, d7\n" /* w11, int16, out0 */ + "vld1.32 {d0-d1}, [%[r2]]!\n" /* load r2, 6-7 */ + "vmlal.s8 q5, d3, d7\n" /* w11, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w11, int16, out2 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vmull.s8 q7, d0, d6\n" /* w10, int16, out3 */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d1, d7\n" /* w11, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w12-w13 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d2, d6\n" /* w12, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w12, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w12, int16, out2 */ + "vmlal.s8 q4, d3, d7\n" /* w13, int16, out0 */ + "vld1.32 {d2-d3}, [%[r2]]!\n" /* load r2, 8-9 */ + "vmlal.s8 q5, d5, d7\n" /* w13, int16, out1 */ + "vmlal.s8 q6, d1, d7\n" /* w13, int16, out2 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vmull.s8 q7, d2, d6\n" /* w12, int16, out3 */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d3, d7\n" /* w13, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w14-w15 */ + "vld1.32 {d5}, [%[r2]]\n" /* load r2, 10 */ + "sub %[r2], %[r2], #16\n" /* r2 = r2 - 16 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d4, d6\n" /* w14, int16, out0 */ + "vmull.s8 q5, d0, d6\n" /* w14, int16, out1 */ + "vmull.s8 q6, d2, d6\n" /* w14, int16, out2 */ + "vmull.s8 q7, d5, d6\n" /* w14, int16, out3 */ + "vld1.32 {d0-d3}, [%[r3]]!\n" /* load r3, 0-3 */ + "vld1.32 {d4-d5}, [%[r3]]!\n" /* load r3, 4-5 */ + /* inr3 */ + "vmlal.s8 q4, d0, d7\n" /* w15, int16, out0 */ + "vmlal.s8 q5, d2, d7\n" /* w15, int16, out1 */ + "vmlal.s8 q6, d4, d7\n" /* w15, int16, out2 */ + "vld1.32 {d0}, [%[r3]]!\n" /* load r3, 6 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to 
out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d0, d7\n" /* w15, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w16-w17 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d1, d6\n" /* w16, int16, out0 */ + "vld1.32 {d1}, [%[r3]]!\n" /* load r3, 7 */ + "vmull.s8 q5, d3, d6\n" /* w16, int16, out1 */ + "vmull.s8 q6, d5, d6\n" /* w16, int16, out2 */ + "vmlal.s8 q4, d2, d7\n" /* w17, int16, out0 */ + "vmlal.s8 q5, d4, d7\n" /* w17, int16, out1 */ + "vmlal.s8 q6, d0, d7\n" /* w17, int16, out2 */ + "vmull.s8 q7, d1, d6\n" /* w16, int16, out3 */ + "vld1.32 {d2}, [%[r3]]!\n" /* load r3, 8 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d2, d7\n" /* w17, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w18-w19 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d3, d6\n" /* w18, int16, out0 */ + "vld1.32 {d3}, [%[r3]]!\n" /* load r3, 9 */ + "vmull.s8 q5, d5, d6\n" /* w18, int16, out1 */ + "vmull.s8 q6, d1, d6\n" /* w18, int16, out2 */ + "vld1.32 {d5}, [%[r3]]\n" /* load r3, 10 */ + "vmlal.s8 q4, d4, d7\n" /* w19, int16, out0 */ + "vmlal.s8 q5, d0, d7\n" /* w19, int16, out1 */ + "vmlal.s8 q6, d2, d7\n" /* w19, int16, out2 */ + "vmull.s8 q7, d3, d6\n" /* w18, int16, out3 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d5, d7\n" /* w19, int16, out3 */ + "sub %[r3], %[r3], #16\n" /* r3 = r3 - 16 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w20-w21 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + "vld1.32 {d0-d3}, [%[r4]]!\n" /* load r4, 0-3 */ + "vld1.32 {d4-d5}, [%[r4]]!\n" /* load r4, 4-5 */ + + /* inr4 */ + "vmull.s8 q4, d0, d6\n" /* w20, int16, out0 */ + "vmull.s8 q5, d2, d6\n" /* w20, int16, out1 */ + "vmull.s8 q6, d4, d6\n" /* w20, int16, out2 */ + "vmlal.s8 q4, d1, d7\n" /* w21, int16, out0 */ + "vld1.32 {d0-d1}, [%[r4]]!\n" /* load r4, 6-7 */ + "vmlal.s8 q5, d3, d7\n" /* w21, int16, out1 */ + "vmlal.s8 q6, d5, d7\n" /* w21, int16, out2 */ + "vaddw.s16 q8, q8, d8\n" /* add to out0 low */ + "vmull.s8 q7, d0, d6\n" /* w20, int16, out3 */ + "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */ + "vaddw.s16 q10, q10, d10\n" /* add to out1 low */ + "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */ + "vmlal.s8 q7, d1, d7\n" /* w21, int16, out3 */ + "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w22-w23 */ + "vaddw.s16 q12, q12, d12\n" /* add to out2 low */ + "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */ + "vaddw.s16 q14, q14, d14\n" /* add to out3 low */ + "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */ + + "vmull.s8 q4, d2, d6\n" /* w22, int16, out0 */ + "vmull.s8 q5, d4, d6\n" /* w22, int16, out1 */ + "vmull.s8 q6, d0, d6\n" /* w22, int16, out2 */ + "vmlal.s8 
q4, d3, d7\n" /* w23, int16, out0 */
+          "vld1.32 {d2-d3}, [%[r4]]!\n" /* load r4, 8-9 */
+          "vmlal.s8 q5, d5, d7\n" /* w23, int16, out1 */
+          "vmlal.s8 q6, d1, d7\n" /* w23, int16, out2 */
+          "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
+          "vmull.s8 q7, d2, d6\n" /* w22, int16, out3 */
+          "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
+          "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
+          "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
+          "vmlal.s8 q7, d3, d7\n" /* w23, int16, out3 */
+          "vld1.32 {d6}, [%[wptr]]!\n" /* load w24 */
+          "vld1.32 {d5}, [%[r4]]\n" /* load r4, 10 */
+          "sub %[r4], %[r4], #16\n" /* r4 = r4 - 16 */
+          "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
+          "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
+          "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
+          "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
+          "sub %[wptr], %[wptr], #200 \n" /* wptr = wptr - 200 */
+
+          "vmull.s8 q4, d4, d6\n" /* w24, int16, out0 */
+          "vmull.s8 q5, d0, d6\n" /* w24, int16, out1 */
+          "vmull.s8 q6, d2, d6\n" /* w24, int16, out2 */
+          "vmull.s8 q7, d5, d6\n" /* w24, int16, out3 */
+          "vld1.32 {d0-d3}, [%[r0]]!\n" /* load r0, 0-3 */
+          "vld1.32 {d6-d7}, [%[wptr]]!\n" /* load w0-w1 */
+          "vaddw.s16 q8, q8, d8\n" /* add to out0 low */
+          "vaddw.s16 q9, q9, d9\n" /* add to out0 hig */
+          "vld1.32 {d4-d5}, [%[r0]]!\n" /* load r0, 4-5 */
+          "vaddw.s16 q10, q10, d10\n" /* add to out1 low */
+          "vaddw.s16 q11, q11, d11\n" /* add to out1 hig */
+          "vst1.32 {d16-d19}, [%[ptr_out0]]!\n" /* store out0 */
+          "vaddw.s16 q12, q12, d12\n" /* add to out2 low */
+          "vaddw.s16 q13, q13, d13\n" /* add to out2 hig */
+          "vst1.32 {d20-d23}, [%[ptr_out0]]!\n" /* store out1 */
+          "vaddw.s16 q14, q14, d14\n" /* add to out3 low */
+          "vaddw.s16 q15, q15, d15\n" /* add to out3 hig */
+          "subs %[cnt], #1\n" /* cnt = cnt - 1 */
+          "vst1.32 {d24-d27}, [%[ptr_out0]]!\n" /* store out2 */
+          "vst1.32 {d28-d31}, [%[ptr_out0]]!\n" /* store out3 */
+          "bne 1b\n" /* branch main loop */
+          : [cnt] "+r"(cnt), [r0] "+r"(inr0), [r1] "+r"(inr1),
+            [r2] "+r"(inr2), [r3] "+r"(inr3), [r4] "+r"(inr4),
+            [ptr_out0] "+r"(ptr_out0), [wptr] "+r"(wptr)
+          :
+          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+            "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
+#endif
+      // clang-format on
+      block_inr0 = block_inr2;
+      block_inr1 = block_inr3;
+      block_inr2 = block_inr4;
+      block_inr3 = block_inr2 + in_len;
+      block_inr4 = block_inr3 + in_len;
+    }
+    write_int32_nchwc8_to_nchw(pre_out,
+                               reinterpret_cast<Dtype*>(dout_batch),
+                               c, c + hout_c_block, h, h + h_kernel,
+                               0, wout_round, chout, hout, wout,
+                               flag_relu, bias_local, flag_bias,
+                               ptr_write, scale + c);
+      }
+    }
+  }
+}
+
+template void conv_depthwise_5x5s2_int8<int8_t>(int8_t* dout,
+    const int8_t* din, const int8_t* weights, const float* scale,
+    const float* bias, bool flag_bias, bool flag_relu, int num, int chin,
+    int hin, int win, int hout, int wout, int padw, int padh, ARMContext* ctx);
+
+template void conv_depthwise_5x5s2_int8<float>(float* dout,
+    const int8_t* din, const int8_t* weights, const float* scale,
+    const float* bias, bool flag_bias, bool flag_relu, int num, int chin,
+    int hin, int win, int hout, int wout, int padw, int padh, ARMContext* ctx);
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
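[Editor's note] As a reading aid only -- not part of the patch -- the NEON
kernels above compute, per channel and output point, the plain 5x5 stride-2
int8 dot product sketched below. Names are illustrative; the asm accumulates
the same products into the int32 accumulators (v24-v31 / q8-q15) for four
output columns at a time, assuming padding was already applied by the
caller's row pre-packing:

    #include <cstdint>
    // One output point of a 5x5, stride-2, int8 depthwise convolution.
    inline int32_t dw5x5s2_ref(const int8_t* din, const int8_t* wei,
                               int win, int oh, int ow) {
      int32_t acc = 0;
      for (int kh = 0; kh < 5; ++kh) {
        for (int kw = 0; kw < 5; ++kw) {
          acc += static_cast<int32_t>(din[(2 * oh + kh) * win + 2 * ow + kw]) *
                 static_cast<int32_t>(wei[5 * kh + kw]);
        }
      }
      return acc;  // scaled/biased to int8 or float by the write-out stage
    }

diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h
index 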
b2d16d18d2300ea51de8c8e9f25664ffdf4aebc7..85404d6a6e2e6246677857be8231e15afa86210d 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -20,6 +20,7 @@ #include "lite/backends/arm/math/sgemm.h" #include "lite/backends/arm/math/type_trans.h" #include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" #include "lite/utils/cp_logging.h" namespace paddle { @@ -28,6 +29,7 @@ namespace arm { namespace math { #define LITEMAX(a, b) ((a) > (b) ? (a) : (b)) +#define LITEMIN(a, b) ((a) < (b) ? (a) : (b)) #define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) template @@ -254,6 +256,7 @@ inline void prepack_input_nxwc4_dw(const float* din, LOG(FATAL) << "prepack_dw_input, valid height must > zero"; } float32x4_t vzero = vdupq_n_f32(0.f); + auto out_data = dout; int size_w = we - ws; int w0 = ws < 0 ? 0 : ws; @@ -269,6 +272,7 @@ inline void prepack_input_nxwc4_dw(const float* din, bool flag_ext_l = left_remain > 0; int left_sl = 4 - left_remain; + int left_valid_sl = left_sl > width ? width : left_sl; uint32x4_t vmask_padl; bool flag_mask_l = false; if (flag_ext_l) { @@ -290,6 +294,7 @@ inline void prepack_input_nxwc4_dw(const float* din, } int size_c = width * height; for (int h = hs; h < he; ++h) { + dout = out_data + (h - hs) * 4 * size_w; auto ptr_c0 = din + cs * size_c + h * width; auto ptr_c1 = ptr_c0 + size_c; auto ptr_c2 = ptr_c1 + size_c; @@ -351,10 +356,10 @@ inline void prepack_input_nxwc4_dw(const float* din, } transpose_4x4(vc0, vc1, vc2, vc3, dout); dout += 16; - ptr_c0 += left_sl; - ptr_c1 += left_sl; - ptr_c2 += left_sl; - ptr_c3 += left_sl; + ptr_c0 += left_valid_sl; + ptr_c1 += left_valid_sl; + ptr_c2 += left_valid_sl; + ptr_c3 += left_valid_sl; } /// valid for (int i = 0; i < cnt_valid; ++i) { @@ -586,7 +591,238 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din, } } } - +// clang-format off +#ifdef __aarch64__ +#define NCHWC1_TRANS_FP32_COMPUTE \ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "movi v20.4s, #0 \n" /* for relu */ \ + "1: \n" /* main loop*/ + +#define NCHWC1_TRANS_FP32_RELU \ + "fmax v0.4s, v0.4s, v20.4s \n" /*relu*/ \ + "fmax v1.4s, v1.4s, v20.4s \n" /*relu*/ \ + "fmax v2.4s, v2.4s, v20.4s \n" /*relu*/ \ + "fmax v3.4s, v3.4s, v20.4s \n" /*relu*/ + +#define NCHWC1_TRANS_FP32_RELU6 \ + "fmin v0.4s, v0.4s, %[six].4s \n" /* relu6 */ \ + "fmin v1.4s, v1.4s, %[six].4s \n" /* relu6 */ \ + "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \ + "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ + +#define NCHWC1_TRANS_FP32_LEAKY_RELU \ + "fcmge v4.4s, v0.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v5.4s, v1.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \ + "fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \ + "fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \ + "fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */ \ + "bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \ + "bif v2.16b, v10.16b, v6.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v7.16b \n" /* choose*/ + +#define NCHWC1_TRANS_FP32_STORE \ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ + \ + "str q0, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q1, 
[%[doutc0r0]], #16 \n" /* store c0r0*/ \
+  "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1*/ \
+  "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1*/ \
+  "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
+  "str q3, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
+  "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1*/ \
+  "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1*/ \
+  \
+  "bne 1b \n" /* jump to main loop*/
+#else
+#define NCHWC1_TRANS_FP32_COMPUTE \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0 \n" \
+  "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data, c0r0 \n" \
+  "vmov.u32 q15, #0 @ dump zero\n" \
+  "1: @ main loop\n"
+
+#define NCHWC1_TRANS_FP32_RELU \
+  "vmax.f32 q0, q0, q15 @ relu\n" \
+  "vmax.f32 q1, q1, q15 @ relu\n" \
+  "vmax.f32 q2, q2, q15 @ relu\n" \
+  "vmax.f32 q3, q3, q15 @ relu\n"
+
+#define NCHWC1_TRANS_FP32_RELU6 \
+  "vmin.f32 q0, q0, %q[six] @ relu6 \n" \
+  "vmin.f32 q1, q1, %q[six] @ relu6 \n" \
+  "vmin.f32 q2, q2, %q[six] @ relu6 \n" \
+  "vmin.f32 q3, q3, %q[six] @ relu6 \n"
+
+#define NCHWC1_TRANS_FP32_LEAKY_RELU \
+  "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
+  "vcge.f32 q6, q1, q15 @ q1 > 0 \n" \
+  "vcge.f32 q7, q2, q15 @ q2 > 0 \n" \
+  "vcge.f32 q8, q3, q15 @ q3 > 0 \n" \
+  "vmul.f32 q9, q0, %q[scale] \n" \
+  "vmul.f32 q10, q1, %q[scale] \n" \
+  "vmul.f32 q11, q2, %q[scale] \n" \
+  "vmul.f32 q12, q3, %q[scale] \n" \
+  "vbif q0, q9, q5 @ choose \n" \
+  "vbif q1, q10, q6 @ choose \n" \
+  "vbif q2, q11, q7 @ choose \n" \
+  "vbif q3, q12, q8 @ choose \n"
+
+#define NCHWC1_TRANS_FP32_STORE \
+  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \
+  "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result \n" \
+  "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
+  \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \
+  "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \
+  "vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result \n" \
+  \
+  "vld1.32 {d4-d7}, [%[ptr_din]]! 
@ load data \n" \
+  \
+  "bne 1b @ jump to main loop\n"
+#endif
+// clang-format on
+inline void act_switch_c1_fp32(const float* din_ptr,
+                               float* doutc0_ptr,
+                               int cnt_loop,
+                               const operators::ActivationParam* act_param) {
+  if (act_param != nullptr && act_param->has_active) {
+    float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
+    float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
+    switch (act_param->active_type) {
+      case lite_api::ActivationType::kRelu:
+#ifdef __aarch64__
+        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
+                         NCHWC1_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop),
+                       [ptr_din] "+r"(din_ptr)
+                     :
+                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                       "v8", "v9", "v10", "v20");
+#else
+        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
+                         NCHWC1_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_ptr),
+                       [cnt] "+r"(cnt_loop)
+                     :
+                     : "q0", "q1", "q2", "q3", "q15");
+#endif
+        break;
+      case lite_api::ActivationType::kRelu6:
+/* 0 <= din <= 6 */
+#ifdef __aarch64__
+        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
+                         NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop),
+                       [ptr_din] "+r"(din_ptr)
+                     : [six] "w"(six)
+                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                       "v8", "v9", "v10", "v20");
+#else
+        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
+                         NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_ptr),
+                       [cnt] "+r"(cnt_loop)
+                     : [six] "w"(six)
+                     : "q0", "q1", "q2", "q3", "q15");
+#endif
+        break;
+      case lite_api::ActivationType::kLeakyRelu:
+/*din = din >= 0 ? din : din * scale*/
+#ifdef __aarch64__
+        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU
+                         NCHWC1_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop),
+                       [ptr_din] "+r"(din_ptr)
+                     : [scale] "w"(scale)
+                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                       "v8", "v9", "v10", "v11", "v20");
+#else
+        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU
+                         NCHWC1_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_ptr),
+                       [cnt] "+r"(cnt_loop)
+                     : [scale] "w"(scale)
+                     : "q0", "q1", "q2", "q3", "q5", "q6", "q7", "q8",
+                       "q9", "q10", "q11", "q12", "q15");
+#endif
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param->active_type)
+                   << " fuse not support";
+    }
+  } else {
+#ifdef __aarch64__
+    asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE
+                 : [doutc0r0] "+r"(doutc0_ptr), [cnt] "+r"(cnt_loop),
+                   [ptr_din] "+r"(din_ptr)
+                 :
+                 : "v0", "v1", "v2", "v3", "v20");
+#else
+    asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE
+                 : [doutc0r0] "+r"(doutc0_ptr), [ptr_din] "+r"(din_ptr),
+                   [cnt] "+r"(cnt_loop)
+                 :
+                 : "q0", "q1", "q2", "q3", "q15");
+#endif
+  }
+}
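+// Editorial sketch, not part of the original patch: per 4-float vector the
+// three fused activations above reduce to the intrinsics below (vzero, vsix
+// and vscale mirror the v20/q15, %[six] and %[scale] asm operands).
+static inline float32x4_t act_c1_fp32_ref(float32x4_t v,
+                                          lite_api::ActivationType type,
+                                          float32x4_t vsix,
+                                          float32x4_t vscale) {
+  float32x4_t vzero = vdupq_n_f32(0.f);
+  switch (type) {
+    case lite_api::ActivationType::kRelu:
+      return vmaxq_f32(v, vzero);
+    case lite_api::ActivationType::kRelu6:
+      return vminq_f32(vmaxq_f32(v, vzero), vsix);
+    case lite_api::ActivationType::kLeakyRelu: {
+      uint32x4_t ge = vcgeq_f32(v, vzero);            // fcmge / vcge.f32
+      return vbslq_f32(ge, v, vmulq_f32(v, vscale));  // bif / vbif
+    }
+    default:
+      return v;  // no fused activation
+  }
+}
 /*write result in outputs
 * input din: [n, c, h, w], output dout: [n, c, h, w]
 */
@@ -602,13 +838,14 @@ inline bool write_to_output_c1_fp32(const float* din,
                                     int height,
                                     int width,
                                     bool flag_relu,
-                                    float* trash_ptr) {
+                                    float* trash_ptr,
+                                    operators::ActivationParam* act_param) {
   if (cs > channel) {
     return true;
   }
   const int c1 = 1;
-  const int w4 = 4;
+  const int w4 = 16;
 
   int size_c_out = width * height;
@@ -620,98 +857,53 @@ inline bool write_to_output_c1_fp32(const float* din,
   int w_round = we - ws;
 
   int cnt = (width - ws) / w4;
-
+  int remain = (width - 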
ws) % w4; for (int i = 0; i < size_h; i++) { int size_w = i * width; float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; const float* din_hei_ptr = ptr_din + i * w_round * c1; if (cnt > 0) { int cnt_loop = cnt; - if (flag_relu) { -#ifdef __aarch64__ - asm volatile( - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "fmax v1.4s, v0.4s, v20.4s \n" /*relu*/ - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "bne 1b \n" /* jump to main loop*/ - : [doutc0r0] "+r"(doutc0_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v20"); -#else - asm volatile( - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - - "vmax.f32 q1, q0, q15 @ relu\n" - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" - - "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile( - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "1: \n" /* main loop*/ - "str q0, [%[doutc0r0]], #16 \n" /* store c2r0*/ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0"); -#else - asm volatile( - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " - "c0r1, c0r2, c0r3\n" - "1: @ main loop\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0"); -#endif - } + act_switch_c1_fp32(din_hei_ptr, doutc0_ptr, cnt_loop, act_param); } - if (we > width) { + if (remain > 0) { int offset = i * w_round * c1 + c1 * w4 * cnt; din_hei_ptr = ptr_din + offset; - int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - din_hei_ptr++; + doutc0_ptr += w4 * cnt; + int j = w4 * cnt; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + din_hei_ptr++; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < width; ++j) { + float tmp = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp, six); + din_hei_ptr++; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+          for (; j < width; ++j) {
+            if (din_hei_ptr[0] >= 0) {
+              *(doutc0_ptr++) = din_hei_ptr[0];
+            } else {
+              *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+            }
+            din_hei_ptr++;
+          }
+          break;
+        default:
+          LOG(FATAL) << "this act_type: "
+                     << static_cast<int>(act_param->active_type)
+                     << " fuse not support";
+      }
     } else {
       for (; j < width; ++j) {
@@ -722,7 +914,224 @@ inline bool write_to_output_c1_fp32(const float* din,
   }
   return true;
 }
-
+// clang-format off
+#ifdef __aarch64__
+#define NCHWC2_TRANS_FP32_COMPUTE \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/ \
+  "movi v20.4s, #0 \n" /* for relu */ \
+  "1: \n" /* main loop*/ \
+  "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ \
+  "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/ \
+  "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ \
+  "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/
+
+#define NCHWC2_TRANS_FP32_RELU \
+  "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/ \
+  "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/
+
+#define NCHWC2_TRANS_FP32_RELU6 \
+  "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
+
+#define NCHWC2_TRANS_FP32_LEAKY_RELU \
+  "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \
+  "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \
+  "bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \
+  "bif v3.16b, v5.16b, v7.16b \n" /* choose*/
+
+#define NCHWC2_TRANS_FP32_STORE \
+  "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \
+  \
+  "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
+  "str q3, [%[doutc1r0]], #16 \n" /* store c1r0*/ \
+  \
+  "bne 1b \n" /* jump to main loop*/
+#else
+#define NCHWC2_TRANS_FP32_COMPUTE \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, c1r0 \n" \
+  "vmov.u32 q15, #0 @ dump zero\n" \
+  "1: @ main loop\n" \
+  "vtrn.32 d0, d1 @ trans data: c0r0, c0r1, c1r0, c1r1 \n" \
+  "vtrn.32 d2, d3 @ trans data: c0r2, c0r3, c1r2, c1r3 \n" \
+  \
+  "vswp d1, d2 @ swap data\n"
+
+#define NCHWC2_TRANS_FP32_RELU \
+  "vmax.f32 q0, q0, q15 @ relu\n" \
+  "vmax.f32 q1, q1, q15 @ relu\n"
+
+#define NCHWC2_TRANS_FP32_RELU6 \
+  "vmin.f32 q0, q0, %q[six] @ relu6 \n" \
+  "vmin.f32 q1, q1, %q[six] @ relu6 \n"
+
+#define NCHWC2_TRANS_FP32_LEAKY_RELU \
+  "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
+  "vcge.f32 q6, q1, q15 @ q1 > 0 \n" \
+  "vmul.f32 q9, q0, %q[scale] \n" \
+  "vmul.f32 q10, q1, %q[scale] \n" \
+  "vbif q0, q9, q5 @ choose \n" \
+  "vbif q1, q10, q6 @ choose \n"
+
+#define NCHWC2_TRANS_FP32_STORE \
+  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" \
+  \
+  "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
+  \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! 
@ load data \n" \ + \ + "bne 1b @ jump to main loop\n" +#endif +// clang-format on +inline void act_switch_c2_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/
+#ifdef __aarch64__
+        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU
+                         NCHWC2_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr),
+                       [doutc1r0] "+r"(doutc1_ptr),
+                       [cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_ptr)
+                     : [scale] "w"(scale)
+                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                       "v8", "v9", "v10", "v20");
+#else
+        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU
+                         NCHWC2_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr),
+                       [doutc1r0] "+r"(doutc1_ptr),
+                       [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
+                     : [scale] "w"(scale)
+                     : "q0", "q1", "q2", "q3", "q5", "q6", "q7", "q8",
+                       "q9", "q10", "q11", "q12", "q15");
+#endif
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param->active_type)
+                   << " fuse not support";
+    }
+  } else {
+#ifdef __aarch64__
+    asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
+                 : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
+                   [cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_ptr)
+                 :
+                 : "v0", "v1", "v2", "v3", "v4", "v5", "v20");
+#else
+    asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
+                 : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
+                   [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
+                 :
+                 : "q0", "q1", "q2", "q3", "q15");
+#endif
+  }
+}
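+// Editorial sketch, not part of the original patch: the vtrn/vswp (or
+// trn1/trn2) pair in NCHWC2_TRANS_FP32_COMPUTE de-interleaves four (c0, c1)
+// pairs; vld2q_f32 expresses the same shuffle directly.
+static inline void c2_deinterleave_ref(const float* ptr_din,
+                                       float* doutc0,
+                                       float* doutc1) {
+  float32x4x2_t v = vld2q_f32(ptr_din);  // val[0]=c0r0..c0r3, val[1]=c1r0..c1r3
+  vst1q_f32(doutc0, v.val[0]);
+  vst1q_f32(doutc1, v.val[1]);
+}
 /*write result in outputs
 * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
 */
@@ -738,11 +1147,11 @@ inline bool write_to_output_c2_fp32(const float* din,
                                     int height,
                                     int width,
                                     bool flag_relu,
-                                    float* trash_ptr) {
+                                    float* trash_ptr,
+                                    operators::ActivationParam* act_param) {
   if (cs > channel) {
     return true;
   }
-
   const int c2 = 2;
   const int w4 = 4;
@@ -775,141 +1184,56 @@ inline bool write_to_output_c2_fp32(const float* din,
     const float* din_hei_ptr = ptr_din + i * w_round * c2;
     if (cnt > 0) {
       int cnt_loop = cnt;
-      if (flag_relu) {
-#ifdef __aarch64__
-        asm volatile(
-            "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, c1r1, c0r2, c1r2, c0r3, c1r3 */
-            "movi v20.4s, #0 \n" /* for relu */
-            "1: \n" /* main loop*/
-            "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
-            "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/
-            "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data */
-            "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/
-            "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/
-
-            "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/
-            "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/
-
-            "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
-
-            "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/
-            "str q3, [%[doutc1r0]], #16 \n" /* store c2r0*/
-
-            "bne 1b \n" /* jump to main loop*/
-
-            : [doutc0r0] "+r"(doutc0_ptr),
-              [doutc1r0] "+r"(doutc1_ptr),
-              [cnt] "+r"(cnt_loop),
-              [ptr_din] "+r"(din_hei_ptr)
-            :
-            : "v0", "v1", "v2", "v3", "v4", "v5", "v20");
-#else
-        asm volatile(
-            "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, c1r0, c0r1, c1r1, c0r2, c1r2, c0r3, c1r3\n"
-            "vmov.u32 q15, #0 @ dump zero\n"
-            "1: @ main loop\n"
-            "vtrn.32 d0, d1 @ trans data: c0r0, c0r1, c1r0, c1r1 \n"
-            "vtrn.32 d2, d3 @ trans data: c0r2, c0r3, c1r2, c1r3 \n"
-
-            "vswp d1, d2 @ swap data\n"
-
-            "vmax.f32 q0, q0, q15 @ relu\n"
-            "vmax.f32 q1, q1, q15 @ relu\n"
-
-            "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
-            "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"
-
-            "subs %[cnt], %[cnt], #1 @ loop count - 1\n"
-
-            "vld1.32 {d0-d3}, [%[ptr_din]]! 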
@ load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, - c1r1, , c0r2, c1r2, c0r3, - c1r3 */ - "1: \n" /* main loop*/ - "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, - c1r1, , c0r2, c1r2, c0r3, - c1r3 */ - "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ - "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - - "str q4, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q5, [%[doutc1r0]], #16 \n" /* store c2r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" - "1: @ main loop\n" - "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " - "c1r0, c1r1 \n" - "vtrn.32 d2, d3 @ trans data:c0r2, c0r3, " - "c1r2, c1r3 \n" - - "vswp d1, d2 @ swap data\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } + act_switch_c2_fp32( + din_hei_ptr, doutc0_ptr, doutc1_ptr, cnt_loop, act_param); } if (we > width) { int offset = i * w_round * c2 + c2 * w4 * cnt; din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - din_hei_ptr += 2; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + din_hei_ptr += 2; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < width; ++j) { + float tmp1 = LITEMAX(din_hei_ptr[0], 0.f); + float tmp2 = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp1, six); + *(doutc1_ptr++) = LITEMIN(tmp2, six); + din_hei_ptr += 2; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+          for (; j < width; ++j) {
+            if (din_hei_ptr[0] >= 0) {
+              *(doutc0_ptr++) = din_hei_ptr[0];
+            } else {
+              *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+            }
+            if (din_hei_ptr[1] >= 0) {
+              *(doutc1_ptr++) = din_hei_ptr[1];
+            } else {
+              *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+            }
+            din_hei_ptr += 2;
+          }
+          break;
+        default:
+          LOG(FATAL) << "this act_type: "
+                     << static_cast<int>(act_param->active_type)
+                     << " fuse not support";
+      }
     } else {
       for (; j < width; ++j) {
@@ -921,7 +1245,309 @@ inline bool write_to_output_c2_fp32(const float* din,
   }
   return true;
 }
-
+// clang-format off
+#ifdef __aarch64__
+#define NCHWC4_TRANS_FP32_COMPUTE \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
+  "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \
+  "movi v20.4s, #0 \n" /* for relu */ \
+  "1: \n" /* main loop*/ \
+  "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ \
+  "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
+  "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ \
+  "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ \
+  "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \
+  "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ \
+  "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ \
+  "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ \
+  "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/
+
+#define NCHWC4_TRANS_FP32_RELU \
+  "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ \
+  "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ \
+  "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \
+  "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/
+
+#define NCHWC4_TRANS_FP32_RELU6 \
+  "fmin v16.4s, v16.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v17.4s, v17.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */ \
+  "fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
+
+#define NCHWC4_TRANS_FP32_LEAKY_RELU \
+  "fcmge v8.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v9.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v10.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fcmge v11.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \
+  "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
+  "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
+  "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
+  "fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \
+  "bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \
+  "bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \
+  "bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \
+  "bif v19.16b, v7.16b, v11.16b \n" /* choose*/
+
+#define NCHWC4_TRANS_FP32_STORE \
+  "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
+  "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \
+  "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ \
+  "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ \
+  \
+  "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \
+  "bne 1b \n" /* jump to main loop*/
+#else
+#define NCHWC4_TRANS_FP32_COMPUTE \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \
+  "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" \ + "vmov.u32 q15, #0 @ dump zero\n" \ + "1: @ main loop\n" \ + "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " \ + "\n" \ + "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " \ + "\n" \ + \ + "vswp d1, d4 @ swap data\n" \ + "vswp d3, d6 @ swap data\n" + +#define NCHWC4_TRANS_FP32_RELU \ + "vmax.f32 q0, q0, q15 @ relu\n" \ + "vmax.f32 q1, q1, q15 @ relu\n" \ + "vmax.f32 q2, q2, q15 @ relu\n" \ + "vmax.f32 q3, q3, q15 @ relu\n" + +#define NCHWC4_TRANS_FP32_RELU6 \ + "vmin.f32 q0, q0, %q[six] @ relu6 \n" \ + "vmin.f32 q1, q1, %q[six] @ relu6 \n" \ + "vmin.f32 q2, q2, %q[six] @ relu6 \n" \ + "vmin.f32 q3, q3, %q[six] @ relu6 \n" + +#define NCHWC4_TRANS_FP32_LEAKY_RELU \ + "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \ + "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \ + "vcge.f32 q7, q2, q15 @ q0 > 0 \n" \ + "vcge.f32 q8, q3, q15 @ q0 > 0 \n" \ + "vmul.f32 q9, q0, %q[scale] \n" \ + "vmul.f32 q10, q1, %q[scale] \n" \ + "vmul.f32 q11, q2, %q[scale] \n" \ + "vmul.f32 q12, q3, %q[scale] \n" \ + "vbif q0, q9, q5 @ choose \n" \ + "vbif q1, q10, q6 @ choose \n" \ + "vbif q2, q11, q7 @ choose \n" \ + "vbif q3, q12, q8 @ choose \n" + +#define NCHWC4_TRANS_FP32_STORE \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \ + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" \ + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" \ + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ + \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \ + \ + "bne 1b @ jump to main loop\n" +#endif +// clang-format on +inline void act_switch_c4_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + float* doutc2_ptr, + float* doutc3_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + 
NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr),
+                       [doutc1r0] "+r"(doutc1_ptr),
+                       [doutc2r0] "+r"(doutc2_ptr),
+                       [doutc3r0] "+r"(doutc3_ptr),
+                       [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
+                     : [six] "w"(six)
+                     : "q0", "q1", "q2", "q3", "q15");
+#endif
+        break;
+      case lite_api::ActivationType::kLeakyRelu:
+/*din = din >= 0 ? din : din * scale*/
+#ifdef __aarch64__
+        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU
+                         NCHWC4_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr),
+                       [doutc1r0] "+r"(doutc1_ptr),
+                       [doutc2r0] "+r"(doutc2_ptr),
+                       [doutc3r0] "+r"(doutc3_ptr),
+                       [cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_ptr)
+                     : [scale] "w"(scale)
+                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                       "v8", "v9", "v10", "v11", "v12", "v13", "v14",
+                       "v16", "v17", "v18", "v19", "v20");
+#else
+        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU
+                         NCHWC4_TRANS_FP32_STORE
+                     : [doutc0r0] "+r"(doutc0_ptr),
+                       [doutc1r0] "+r"(doutc1_ptr),
+                       [doutc2r0] "+r"(doutc2_ptr),
+                       [doutc3r0] "+r"(doutc3_ptr),
+                       [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
+                     : [scale] "w"(scale)
+                     : "q0", "q1", "q2", "q3", "q5", "q6", "q7", "q8",
+                       "q9", "q10", "q11", "q12", "q15");
+#endif
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param->active_type)
+                   << " fuse not support";
+    }
+  } else {
+#ifdef __aarch64__
+    asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
+                 : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
+                   [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr),
+                   [cnt] "+r"(cnt_loop), [ptr_din] "+r"(din_ptr)
+                 :
+                 : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11",
+                   "v16", "v17", "v18", "v19");
+#else
+    asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
+                 : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr),
+                   [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr),
+                   [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop)
+                 :
+                 : "q0", "q1", "q2", "q3", "q15");
+#endif
+  }
+}
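+// Editorial sketch, not part of the original patch: the two trn1/trn2
+// levels (or vtrn.32 + vswp) in NCHWC4_TRANS_FP32_COMPUTE form a 4x4
+// transpose; vld4q_f32 is the intrinsic equivalent.
+static inline void c4_transpose_ref(const float* ptr_din,
+                                    float* doutc0, float* doutc1,
+                                    float* doutc2, float* doutc3) {
+  float32x4x4_t v = vld4q_f32(ptr_din);  // de-interleave 4 channels x 4 pixels
+  vst1q_f32(doutc0, v.val[0]);
+  vst1q_f32(doutc1, v.val[1]);
+  vst1q_f32(doutc2, v.val[2]);
+  vst1q_f32(doutc3, v.val[3]);
+}
 /*write result in outputs
 * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
 */
@@ -937,11 +1563,13 @@ inline bool write_to_output_c4_fp32(const float* din,
                                     int height,
                                     int width,
                                     bool flag_relu,
-                                    float* trash_ptr) {
+                                    float* trash_ptr,
+                                    operators::ActivationParam* act_param) {
   const int c4 = 4;
   const int w4 = 4;
   const int w_round = we - ws;
   const int ch_n = ce - cs;
+
   if (ch_n != 4) {
     LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is "
                   "more than zero";
@@ -958,7 +1586,9 @@ inline bool write_to_output_c4_fp32(const float* din,
 
   int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
 
-  int cnt = (width - ws) / w4;
+  int valid_we = we > width ? 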
width : we; + int cnt = (valid_we - ws) / w4; + int remain = valid_we - ws - cnt * w4; for (int i = 0; i < size_h; i++) { int size_w = i * width; @@ -981,206 +1611,751 @@ inline bool write_to_output_c4_fp32(const float* din, const float* din_hei_ptr = ptr_din + i * w_round * ch_n; if (cnt > 0) { int cnt_loop = cnt; - if (flag_relu) { -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ - "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " - "\n" - "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " - "\n" - - "vswp d1, d4 @ swap data\n" - "vswp d3, d6 @ swap data\n" - - "vmax.f32 q0, q0, q15 @ relu\n" - "vmax.f32 q1, q1, q15 @ relu\n" - "vmax.f32 q2, q2, q15 @ relu\n" - "vmax.f32 q3, q3, q15 @ relu\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v8", - "v9", - "v10", - "v11", - "v16", - "v17", - "v18", - "v19"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " - "\n" - "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " - "\n" - - "vswp d1, d4 @ swap data\n" - "vswp d3, d6 @ swap data\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3"); -#endif - } + act_switch_c4_fp32(din_hei_ptr, + doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + cnt_loop, + act_param); } - if (we > width) { + if (remain > 0) { int offset = i * w_round * c4 + c4 * w4 * cnt; din_hei_ptr = ptr_din + offset; - int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); - din_hei_ptr += w4; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; + doutc2_ptr += w4 * cnt; + doutc3_ptr += w4 * cnt; + int j = 0; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < remain; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + din_hei_ptr += 4; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < remain; ++j) { + float tmp1 = LITEMAX(din_hei_ptr[0], 0.f); + float tmp2 = LITEMAX(din_hei_ptr[1], 0.f); + float tmp3 = LITEMAX(din_hei_ptr[2], 0.f); + float tmp4 = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp1, six); + *(doutc1_ptr++) = LITEMIN(tmp2, six); + *(doutc2_ptr++) = LITEMIN(tmp3, six); + *(doutc3_ptr++) = LITEMIN(tmp4, six); + din_hei_ptr += 4; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/
+          for (; j < remain; ++j) {
+            if (din_hei_ptr[0] >= 0) {
+              *(doutc0_ptr++) = din_hei_ptr[0];
+            } else {
+              *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+            }
+            if (din_hei_ptr[1] >= 0) {
+              *(doutc1_ptr++) = din_hei_ptr[1];
+            } else {
+              *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+            }
+            if (din_hei_ptr[2] >= 0) {
+              *(doutc2_ptr++) = din_hei_ptr[2];
+            } else {
+              *(doutc2_ptr++) = din_hei_ptr[2] * scale;
+            }
+            if (din_hei_ptr[3] >= 0) {
+              *(doutc3_ptr++) = din_hei_ptr[3];
+            } else {
+              *(doutc3_ptr++) = din_hei_ptr[3] * scale;
+            }
+            din_hei_ptr += 4;
+          }
+          break;
+        default:
+          LOG(FATAL) << "this act_type: "
+                     << static_cast<int>(act_param->active_type)
+                     << " fuse not support";
+      }
     } else {
-      for (; j < width; ++j) {
+      for (; j < remain; ++j) {
         *(doutc0_ptr++) = din_hei_ptr[0];
         *(doutc1_ptr++) = din_hei_ptr[1];
         *(doutc2_ptr++) = din_hei_ptr[2];
         *(doutc3_ptr++) = din_hei_ptr[3];
-        din_hei_ptr += w4;
+        din_hei_ptr += 4;
       }
     }
   }
 }
 return true;
}
+// clang-format off
+#ifdef __aarch64__
+#define NCHWC8_TRANS_FP32_COMPUTE \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
+  "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \
+  "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q4, q5 */ \
+  "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q6, q7 */ \
+  "movi v20.4s, #0 \n" /* for relu */ \
+  "1: \n" /* main loop*/ \
+  "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q2*/ \
+  "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q2*/ \
+  "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q1, q3*/ \
+  "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q1, q3*/ \
+  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
+  \
+  "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q4, q6*/ \
+  "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q4, q6*/ \
+  "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q5, q7*/ \
+  "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q5, q7*/ \
+  "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ \
+  \
+  "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ \
+  "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ \
+  "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ \
+  "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ \
+  "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q4, q5 */ \
+  \
+  "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ \
+  "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ \
+  "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ \
+  "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ \
+  "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q6, q7 */
+
+#define NCHWC8_TRANS_FP32_RELU \
+  "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ \
+  "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ \
+  "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \
+  "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ \
+  \
+  "fmax v8.4s, v8.4s, v20.4s \n" /*relu*/ \
+  "fmax v9.4s, v9.4s, v20.4s \n" /*relu*/ \
+  "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
+  "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/
+
+#define NCHWC8_TRANS_FP32_RELU6 \
+  "fmin v16.4s, v16.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v17.4s, v17.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v18.4s, v18.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v19.4s, v19.4s, %[six].4s \n" /*relu6*/ \
+  \
+  "fmin v8.4s, v8.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v9.4s, v9.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v12.4s, v12.4s, %[six].4s \n" /*relu6*/ \
+  "fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/
+
+#define NCHWC8_TRANS_FP32_LEAKY_RELU \
+  "fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
+ "fcmge v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ + \ + "fcmge v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \ + \ + "fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \ + "fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \ + "fmul v27.4s, v18.4s, %[scale].4s \n" /* mul */ \ + "fmul v28.4s, v19.4s, %[scale].4s \n" /* mul */ \ + \ + "fmul v29.4s, v8.4s, %[scale].4s \n" /* mul */ \ + "fmul v30.4s, v9.4s, %[scale].4s \n" /* mul */ \ + "fmul v31.4s, v12.4s, %[scale].4s \n" /* mul */ \ + \ + "bif v16.16b, v25.16b, v10.16b \n" /* choose*/ \ + "bif v17.16b, v26.16b, v11.16b \n" /* choose*/ \ + "bif v18.16b, v27.16b, v14.16b \n" /* choose*/ \ + "bif v19.16b, v28.16b, v15.16b \n" /* choose*/ \ + "fmul v25.4s, v13.4s, %[scale].4s \n" /* mul */ \ + \ + "bif v8.16b, v29.16b, v21.16b \n" /* choose*/ \ + "bif v9.16b, v30.16b, v22.16b \n" /* choose*/ \ + "bif v12.16b, v31.16b, v23.16b \n" /* choose*/ \ + "bif v13.16b, v25.16b, v24.16b \n" /* choose*/ + +#define NCHWC8_TRANS_FP32_STORE \ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ \ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ \ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ \ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ \ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ \ + \ + "bne 1b \n" /* jump to main loop*/ + +#else +#define NCHWC8_TRANS_FP32_COMPUTE \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" \ + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" \
+  "vmov.u32 q15, #0 @ dump zero\n" \
+  "1: @ main loop\n" \
+  "vtrn.32 q0, q2 @ trans q0, q2 \n" \
+  "vtrn.32 q4, q6 @ trans q4, q6 \n" \
+  "vswp.32 d1, d8 @ swap d1, d8 \n" \
+  "vswp.32 d5, d12 @ swap d5, d12\n" \
+  \
+  "vtrn.32 q1, q3 @ trans q1, q3 \n" \
+  "vtrn.32 q5, q7 @ trans q5, q7 \n" \
+  "vswp.32 d3, d10 @ swap d3, d10\n" \
+  "vswp.32 d7, d14 @ swap d7, d14\n"
+
+#define NCHWC8_TRANS_FP32_RELU \
+  "vmax.f32 q0, q0, q15 @ relu\n" \
+  "vmax.f32 q1, q1, q15 @ relu\n" \
+  "vmax.f32 q2, q2, q15 @ relu\n" \
+  "vmax.f32 q3, q3, q15 @ relu\n" \
+  \
+  "vmax.f32 q4, q4, q15 @ relu\n" \
+  "vmax.f32 q5, q5, q15 @ relu\n" \
+  "vmax.f32 q6, q6, q15 @ relu\n" \
+  "vmax.f32 q7, q7, q15 @ relu\n"
+
+#define NCHWC8_TRANS_FP32_RELU6 \
+  "vmin.f32 q0, q0, %q[six] @ relu6\n" \
+  "vmin.f32 q1, q1, %q[six] @ relu6\n" \
+  "vmin.f32 q2, q2, %q[six] @ relu6\n" \
+  "vmin.f32 q3, q3, %q[six] @ relu6\n" \
+  \
+  "vmin.f32 q4, q4, %q[six] @ relu6\n" \
+  "vmin.f32 q5, q5, %q[six] @ relu6\n" \
+  "vmin.f32 q6, q6, %q[six] @ relu6\n" \
+  "vmin.f32 q7, q7, %q[six] @ relu6\n"
+
+#define NCHWC8_TRANS_FP32_LEAKY_RELU \
+  "vmov.u32 q15, #0 @ re-zero: q15 is used as a product temp below\n" \
+  "vcge.f32 q9, q0, q15 @ q0 > 0 \n" \
+  "vcge.f32 q10, q1, q15 @ q1 > 0 \n" \
+  "vcge.f32 q11, q2, q15 @ q2 > 0 \n" \
+  "vcge.f32 q12, q3, q15 @ q3 > 0 \n" \
+  "vmul.f32 q13, q0, %q[scale] \n" \
+  "vmul.f32 q14, q1, %q[scale] \n" \
+  "vmul.f32 q15, q2, %q[scale] \n" \
+  \
+  "vbif q0, q13, q9 @ choose \n" \
+  "vmul.f32 q9, q3, %q[scale] \n" \
+  \
+  "vbif q1, q14, q10 @ choose \n" \
+  "vbif q2, q15, q11 @ choose \n" \
+  "vbif q3, q9, q12 @ choose \n" \
+  \
+  "vmov.u32 q15, #0 @ re-zero before comparing the second half\n" \
+  "vcge.f32 q9, q4, q15 @ q4 > 0 \n" \
+  "vcge.f32 q10, q5, q15 @ q5 > 0 \n" \
+  "vcge.f32 q11, q6, q15 @ q6 > 0 \n" \
+  "vcge.f32 q12, q7, q15 @ q7 > 0 \n" \
+  "vmul.f32 q13, q4, %q[scale] \n" \
+  "vmul.f32 q14, q5, %q[scale] \n" \
+  "vmul.f32 q15, q6, %q[scale] \n" \
+  \
+  "vbif q4, q13, q9 @ choose \n" \
+  "vmul.f32 q9, q7, %q[scale] \n" \
+  \
+  "vbif q5, q14, q10 @ choose \n" \
+  "vbif q6, q15, q11 @ choose \n" \
+  "vbif q7, q9, q12 @ choose \n"
+
+#define NCHWC8_TRANS_FP32_STORE \
+  "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
+  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" \
+  \
+  "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \
+  "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" \
+  \
+  "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" \
+  "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" \
+  \
+  "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" \
+  "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" \
+ \
+ "bne 1b @ jump to main loop\n"
+
+#endif
+// clang-format on
+inline void act_switch_c8_fp32(const float* din_ptr,
+ float* doutc0_ptr,
+ float* doutc1_ptr,
+ float* doutc2_ptr,
+ float* doutc3_ptr,
+ float* doutc4_ptr,
+ float* doutc5_ptr,
+ float* doutc6_ptr,
+ float* doutc7_ptr,
+ int cnt_loop,
+ const operators::ActivationParam* act_param) {
+ if (act_param != nullptr && act_param->has_active) {
+ float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
+ float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
+ switch (act_param->active_type) {
+ case lite_api::ActivationType::kRelu:
+#ifdef __aarch64__
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
+ NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [cnt] "+r"(cnt_loop),
+ [ptr_din] "+r"(din_ptr)
+ :
+ : "v0",
+ "v1",
+ "v2",
+ "v3",
+ "v4",
+ "v5",
+ "v6",
+ "v7",
+ "v8",
+ "v9",
+ "v10",
+ "v11",
+ "v12",
+ "v13",
+ "v14",
+ "v15",
+ "v16",
+ "v17",
+ "v18",
+ "v19",
+ "v20");
+#else
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
+ NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [ptr_din] "+r"(din_ptr),
+ [cnt] "+r"(cnt_loop)
+ :
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
+#endif
+ break;
+ case lite_api::ActivationType::kRelu6:
+/* 0 <= din <= 6 */
+#ifdef __aarch64__
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU6
+ NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [cnt] "+r"(cnt_loop),
+ [ptr_din] "+r"(din_ptr)
+ : [six] "w"(six)
+ : "v0",
+ "v1",
+ "v2",
+ "v3",
+ "v4",
+ "v5",
+ "v6",
+ "v7",
+ "v8",
+ "v9",
+ "v10",
+ "v11",
+ "v12",
+ "v13",
+ "v14",
+ "v15",
+ "v16",
+ "v17",
+ "v18",
+ "v19",
+ "v20");
+#else
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
+ NCHWC8_TRANS_FP32_RELU6 NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [ptr_din] "+r"(din_ptr),
+ [cnt] "+r"(cnt_loop)
+ : [six] "w"(six)
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
+#endif
+ break;
+ case lite_api::ActivationType::kLeakyRelu:
+/*din = din >= 0 ? din : din * scale*/
+#ifdef __aarch64__
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
+ NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [cnt] "+r"(cnt_loop),
+ [ptr_din] "+r"(din_ptr)
+ : [scale] "w"(scale)
+ : "v0",
+ "v1",
+ "v2",
+ "v3",
+ "v4",
+ "v5",
+ "v6",
+ "v7",
+ "v8",
+ "v9",
+ "v10",
+ "v11",
+ "v12",
+ "v13",
+ "v14",
+ "v15",
+ "v16",
+ "v17",
+ "v18",
+ "v19",
+ "v20",
+ "v21",
+ "v22",
+ "v23",
+ "v24",
+ "v25",
+ "v26",
+ "v27",
+ "v28",
+ "v29",
+ "v30",
+ "v31");
+#else
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
+ NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [ptr_din] "+r"(din_ptr),
+ [cnt] "+r"(cnt_loop)
+ : [scale] "w"(scale)
+ : "q0",
+ "q1",
+ "q2",
+ "q3",
+ "q4",
+ "q5",
+ "q6",
+ "q7",
+ "q9",
+ "q10",
+ "q11",
+ "q12",
+ "q13",
+ "q14",
+ "q15");
+#endif
+ break;
+ default:
+ LOG(FATAL) << "this act_type: "
+ << static_cast<int>(act_param->active_type)
+ << " fuse not supported";
+ }
+ } else {
+#ifdef __aarch64__
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [cnt] "+r"(cnt_loop),
+ [ptr_din] "+r"(din_ptr)
+ :
+ : "v0",
+ "v1",
+ "v2",
+ "v3",
+ "v4",
+ "v5",
+ "v6",
+ "v7",
+ "v8",
+ "v9",
+ "v10",
+ "v11",
+ "v12",
+ "v13",
+ "v14",
+ "v15",
+ "v16",
+ "v17",
+ "v18",
+ "v19",
+ "v20");
+#else
+ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
+ : [doutc0r0] "+r"(doutc0_ptr),
+ [doutc1r0] "+r"(doutc1_ptr),
+ [doutc2r0] "+r"(doutc2_ptr),
+ [doutc3r0] "+r"(doutc3_ptr),
+ [doutc4r0] "+r"(doutc4_ptr),
+ [doutc5r0] "+r"(doutc5_ptr),
+ [doutc6r0] "+r"(doutc6_ptr),
+ [doutc7r0] "+r"(doutc7_ptr),
+ [ptr_din] "+r"(din_ptr),
+ [cnt] "+r"(cnt_loop)
+ :
+ : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
+#endif
+ }
+}
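
Review note: for readers less fluent in NEON, the three fused branches above reduce to simple per-element math during the c8-to-planar write-back. A scalar sketch of that math (illustrative only, not part of the patch; `act_ref` is a hypothetical name, and `six`/`alpha` mirror `Relu_clipped_coef`/`Leaky_relu_alpha`):

  // Illustrative scalar equivalent of the fused activations above.
  inline float act_ref(float x, lite_api::ActivationType type, float six, float alpha) {
    switch (type) {
      case lite_api::ActivationType::kRelu:
        return x > 0.f ? x : 0.f;
      case lite_api::ActivationType::kRelu6: {
        float t = x > 0.f ? x : 0.f;  // relu first
        return t < six ? t : six;     // then clamp at six
      }
      case lite_api::ActivationType::kLeakyRelu:
        return x >= 0.f ? x : x * alpha;
      default:
        return x;  // no activation fused
    }
  }
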
+
+#ifdef __aarch64__
+#define LOAD_DATA \
+ "1: \n" \
+ "ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
+ "ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
+ "ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
+ "ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/
+#define DO_RELU \
+ "fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
+ "fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
+ "fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
+ "fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */
+#define DO_RELU6 \
+ "fmin v0.4s, v0.4s, %[vsix].4s \n" /* vminq_f32() */ \
+ "fmin v1.4s, v1.4s, %[vsix].4s \n" /* vminq_f32() */ \
+ "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vminq_f32() */ \
+ "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vminq_f32() */
+#define DO_LEAKY_RELU \
+ "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
+ "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
+ "fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
+ "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
+ "fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
+ "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
+ "fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
+ "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
+ "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
+ "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
+ "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
+ "bif v3.16b, v11.16b, v10.16b \n" /* choose*/
+#define DO_STORE \
+ "subs %w[cnt], %w[cnt], #1 \n" \
+ "st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
+ "st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
+ "st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
+ "st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
+ "bne 1b \n"
+#else
+#define LOAD_DATA \
+ "1: \n" \
+ "vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
+ "vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
+ "vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
+ "vld1.32 {d12-d13}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n"
+#define DO_RELU \
+ "vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \
+ "vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \
+ "vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \
+ "vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n"
+#define DO_RELU6 \
+ "vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \
+ "vmin.f32 q4, q4, %q[vsix] @ vminq_f32() \n" \
+ "vmin.f32 q5, q5, %q[vsix] @ vminq_f32() \n" \
+ "vmin.f32 q6, q6, %q[vsix] @ vminq_f32() \n"
+#define DO_LEAKY_RELU \
+ "vcge.f32 q7, q3, %q[vzero] @ vcgeq_f32 \n" \
+ "vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \
+ "vcge.f32 q9, q4, %q[vzero] @ vcgeq_f32 \n" \
+ "vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \
+ "vcge.f32 q11, q5, %q[vzero] @ vcgeq_f32 \n" \
+ "vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \
+ "vcge.f32 q13, q6, %q[vzero] @ vcgeq_f32 \n" \
+ "vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \
+ "vbif q3, q8, q7 @ choose \n" \
+ "vbif q4, q10, q9 @ choose \n" \
+ "vbif q5, q12, q11 @ choose \n" \
+ "vbif q6, q14, q13 @ choose \n"
+#define DO_STORE \
+ "subs %[cnt], #1 \n" \
+ "vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
+ "vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \
+ "vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \
+ "vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \
+ "bne 1b \n"
+#endif
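
Review note: DO_LEAKY_RELU is the assembly form of the standard compare/scale/select idiom. One 4-lane step written with NEON intrinsics would look roughly like the sketch below (`leaky4` is a hypothetical helper, shown only to document the pattern; the macros unroll four such steps so the independent compare/multiply pairs can overlap in the pipeline):

  #include <arm_neon.h>

  // Sketch of one 4-lane leaky-relu step, mirroring fcmge/fmul/bif
  // (vcge/vmul/vbif on armv7).
  static inline float32x4_t leaky4(float32x4_t x, float32x4_t vscale) {
    uint32x4_t ge = vcgeq_f32(x, vdupq_n_f32(0.f));  // lane mask: x >= 0
    float32x4_t xs = vmulq_f32(x, vscale);           // x * alpha
    return vbslq_f32(ge, x, xs);                     // keep x where mask set
  }
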
+/*
+ * Apply the fused activation to the data.
+ * Currently supports the relu, relu6 and leaky_relu activations.
+ */
+inline void act_switch_process(float* src,
+ float* dst,
+ int size,
+ const operators::ActivationParam* act_param) {
+ int cnt = size >> 4;
+ int remain = size % 16;
+ float32x4_t vzero = vdupq_n_f32(0.f);
+ if (act_param != nullptr) {
+ float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef);
+ float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha);
+ if (cnt > 0) {
+ switch (act_param->active_type) {
+ case lite_api::ActivationType::kRelu:
+#ifdef __aarch64__
+ asm volatile(
+ LOAD_DATA DO_RELU DO_STORE
+ : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+ : [vzero] "w"(vzero)
+ : "memory", "cc", "v0", "v1", "v2", "v3");
+#else
+ asm volatile(
+ LOAD_DATA DO_RELU DO_STORE
+ : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+ : [vzero] "w"(vzero)
+ : "memory", "cc", "q3", "q4", "q5", "q6");
+#endif
+ break;
+ case lite_api::ActivationType::kRelu6:
+#ifdef __aarch64__
+ asm volatile(
+ LOAD_DATA DO_RELU DO_RELU6 DO_STORE
+ : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+ : [vzero] "w"(vzero), [vsix] "w"(vsix)
+ : "memory", "cc", "v0", "v1", "v2", "v3");
+#else
+ asm volatile(
+ LOAD_DATA DO_RELU DO_RELU6 DO_STORE
+ : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+ : [vzero] "w"(vzero), [vsix] "w"(vsix)
+ : "memory", "cc", "q3", "q4", "q5", "q6");
+#endif
+ break;
+ case lite_api::ActivationType::kLeakyRelu:
+#ifdef __aarch64__
+ asm volatile(
+ LOAD_DATA DO_LEAKY_RELU DO_STORE
+ : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+ : [vzero] "w"(vzero), [vscale] "w"(vscale)
+ : "memory",
+ "cc",
+ "v0",
+ "v1",
+ "v2",
+ "v3",
+ "v4",
+ "v5",
+ "v6",
+ "v7",
+ "v8",
+ "v9",
+ "v10",
+ "v11");
+#else
+ asm volatile(
+ LOAD_DATA DO_LEAKY_RELU DO_STORE
+ : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
+ : [vzero] "w"(vzero), [vscale] "w"(vscale)
+ : "memory",
+ "cc",
+ "q3",
+ "q4",
+ "q5",
+ "q6",
+ "q7",
+ "q8",
+ "q9",
+ "q10",
+ "q11",
+ "q12",
+ "q13",
+ "q14");
+#endif
+ break;
+ default:
+ LOG(FATAL) << "this act_type: "
+ << static_cast<int>(act_param->active_type)
+ << " fuse not supported";
+ }
+ }
+ // remain
+ switch (act_param->active_type) {
+ case lite_api::ActivationType::kRelu:
+ for (int i = 0; i < remain; i++) {
+ *dst = *src >= 0.f ? *src : 0.f;
+ src++;
+ dst++;
+ }
+ break;
+ case lite_api::ActivationType::kRelu6:
+ for (int i = 0; i < remain; i++) {
+ float tmp = *src >= 0.f ? *src : 0.f;
+ *dst = tmp <= act_param->Relu_clipped_coef
+ ? tmp
+ : act_param->Relu_clipped_coef;
+ src++;
+ dst++;
+ }
+ break;
+ case lite_api::ActivationType::kLeakyRelu:
+ for (int i = 0; i < remain; i++) {
+ if (*src >= 0.f) {
+ *dst = *src;
+ } else {
+ *dst = *src * act_param->Leaky_relu_alpha;
+ }
+ src++;
+ dst++;
+ }
+ break;
+ default:
+ LOG(FATAL) << "this act_type: "
+ << static_cast<int>(act_param->active_type)
+ << " fuse not supported";
+ }
+ }
+}
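
Review note: each asm iteration of act_switch_process handles 16 floats (four 4-lane registers), hence cnt = size >> 4 with a scalar tail of size % 16 elements. A minimal call-site sketch, assuming `src_buf`/`dst_buf` are hypothetical buffers of at least `size` floats:

  // Apply a fused relu6 over a contiguous buffer of 100 floats:
  // 6 vector iterations (96 elements) plus a 4-element scalar tail.
  operators::ActivationParam act;
  act.has_active = true;
  act.active_type = lite_api::ActivationType::kRelu6;
  act.Relu_clipped_coef = 6.f;
  act_switch_process(src_buf, dst_buf, 100, &act);
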
 /*write result in outputs
 * input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
@@ -1199,7 +2374,8 @@ inline bool write_to_output_c8_fp32(const float* din,
 int height,
 int width,
 bool flag_relu,
- float* trash_ptr) {
+ float* trash_ptr,
+ operators::ActivationParam* act_param) {
 if (ch_n != 8 || hei_n <= 0) {
 LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
 return false;
 }
@@ -1220,392 +2396,161 @@ inline bool write_to_output_c8_fp32(const float* din,
 int size_h = (he > height ?
height : he) - hs; // size_h == hei_n int valid_w = we - ws; + int w4 = 4; int cnt = valid_w / 4; if (we > width) { cnt--; } - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - float* doutc4_ptr = doutc4r0 + size_w; - float* doutc5_ptr = doutc5r0 + size_w; - float* doutc6_ptr = doutc6r0 + size_w; - float* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const float* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ - "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ - - "fmax v8.4s, v8.4s, v20.4s \n" /*relu*/ - "fmax v9.4s, v9.4s, v20.4s \n" /*relu*/ - "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ - "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - 
[doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "vmax.f32 q0, q0, q15 @ relu\n" - "vmax.f32 q1, q1, q15 @ relu\n" - "vmax.f32 q2, q2, q15 @ relu\n" - "vmax.f32 q3, q3, q15 @ relu\n" - - "vmax.f32 q4, q4, q15 @ relu\n" - "vmax.f32 q5, q5, q15 @ relu\n" - "vmax.f32 q6, q6, q15 @ relu\n" - "vmax.f32 q7, q7, q15 @ relu\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4", "q15"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); - *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f); - *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f); - *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f); - *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f); - din_hei_ptr += 8; - } + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + float* doutc4_ptr = doutc4r0 + size_w; + float* doutc5_ptr = doutc5r0 + size_w; + float* doutc6_ptr = doutc6r0 + size_w; + float* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; } } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - float* doutc4_ptr = doutc4r0 + size_w; - float* doutc5_ptr = doutc5r0 + size_w; - float* doutc6_ptr = doutc6r0 + size_w; - float* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: + ptr_din = din + i * valid_w * ch_n; + const float* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; + act_switch_c8_fp32(din_hei_ptr, + doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + doutc4_ptr, + doutc5_ptr, + doutc6_ptr, + doutc7_ptr, + cnt_loop, + act_param); + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; + doutc2_ptr += w4 * cnt; + doutc3_ptr += w4 * cnt; + doutc4_ptr += w4 * cnt; + doutc5_ptr += w4 * cnt; + doutc6_ptr += w4 * cnt; + doutc7_ptr += w4 * cnt; + int i = we - 4; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f); + 
*(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f);
+ *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f);
+ *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f);
+ din_hei_ptr += 8;
+ }
+ break;
+ case lite_api::ActivationType::kRelu6:
+ /* 0 <= din <= 6 */
+ for (; i < width; ++i) {
+ float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
+ float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
+ float tmp3 = LITEMAX(din_hei_ptr[2], 0.f);
+ float tmp4 = LITEMAX(din_hei_ptr[3], 0.f);
+ float tmp5 = LITEMAX(din_hei_ptr[4], 0.f);
+ float tmp6 = LITEMAX(din_hei_ptr[5], 0.f);
+ float tmp7 = LITEMAX(din_hei_ptr[6], 0.f);
+ float tmp8 = LITEMAX(din_hei_ptr[7], 0.f);
+ *(doutc0_ptr++) = LITEMIN(tmp1, six);
+ *(doutc1_ptr++) = LITEMIN(tmp2, six);
+ *(doutc2_ptr++) = LITEMIN(tmp3, six);
+ *(doutc3_ptr++) = LITEMIN(tmp4, six);
+ *(doutc4_ptr++) = LITEMIN(tmp5, six);
+ *(doutc5_ptr++) = LITEMIN(tmp6, six);
+ *(doutc6_ptr++) = LITEMIN(tmp7, six);
+ *(doutc7_ptr++) = LITEMIN(tmp8, six);
+ din_hei_ptr += 8;
+ }
+ break;
+ case lite_api::ActivationType::kLeakyRelu:
+ /*din = din >= 0 ? din : din * scale*/
+ for (; i < width; ++i) {
+ if (din_hei_ptr[0] >= 0) {
+ *(doutc0_ptr++) = din_hei_ptr[0];
+ } else {
+ *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+ }
+ if (din_hei_ptr[1] >= 0) {
+ *(doutc1_ptr++) = din_hei_ptr[1];
+ } else {
+ *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+ }
+ if (din_hei_ptr[2] >= 0) {
+ *(doutc2_ptr++) = din_hei_ptr[2];
+ } else {
+ *(doutc2_ptr++) = din_hei_ptr[2] * scale;
+ }
+ if (din_hei_ptr[3] >= 0) {
+ *(doutc3_ptr++) = din_hei_ptr[3];
+ } else {
+ *(doutc3_ptr++) = din_hei_ptr[3] * scale;
+ }
+ if (din_hei_ptr[4] >= 0) {
+ *(doutc4_ptr++) = din_hei_ptr[4];
+ } else {
+ *(doutc4_ptr++) = din_hei_ptr[4] * scale;
+ }
+ if (din_hei_ptr[5] >= 0) {
+ *(doutc5_ptr++) = din_hei_ptr[5];
+ } else {
+ *(doutc5_ptr++) = din_hei_ptr[5] * scale;
+ }
+ if (din_hei_ptr[6] >= 0) {
+ *(doutc6_ptr++) = din_hei_ptr[6];
+ } else {
+ *(doutc6_ptr++) = din_hei_ptr[6] * scale;
+ }
+ if (din_hei_ptr[7] >= 0) {
+ *(doutc7_ptr++) = din_hei_ptr[7];
+ } else {
+ *(doutc7_ptr++) = din_hei_ptr[7] * scale;
+ }
+ din_hei_ptr += 8;
+ }
+ break;
+ default:
+ LOG(FATAL) << "this act_type: "
+ << static_cast<int>(act_param->active_type)
+ << " fuse not supported";
+ }
- }
- ptr_din = din + i * valid_w * ch_n;
- const float* din_hei_ptr = ptr_din;
- if (cnt > 0) {
- int cnt_loop = cnt;
-#ifdef __aarch64__
- asm volatile(
- "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
- "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
- "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
- "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
- "1: \n" /* main loop*/
- "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
- "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
- "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
- "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
- "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
-
- "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
- "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
- "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
- "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
- "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
-
- "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
- "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
- "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
- "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
- "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
-
- "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " - "pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " - "pointer\n" - - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [doutc2r0] "+r"(doutc2_ptr), - [doutc3r0] "+r"(doutc3_ptr), - [doutc4r0] "+r"(doutc4_ptr), - [doutc5r0] "+r"(doutc5_ptr), - [doutc6r0] "+r"(doutc6_ptr), - [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4"); -#endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; + } else { for (; i < width; ++i) { *(doutc0_ptr++) = din_hei_ptr[0]; *(doutc1_ptr++) = din_hei_ptr[1]; diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index 1a23982cd575afb6b249390de7081165c03414b9..4c5f284a19f615382ea04904184427f569f95ff3 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_3x3s2_depthwise_fp32(const float* i_data, @@ -51,6 +52,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s1_fp32(const float* din, @@ -66,7 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din, const float* bias, int pad, bool flag_bias, - bool flag_relu, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2_fp32(const float* din, @@ -82,39 +84,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, const float* bias, int pad, bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_3x3p0_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_3x3p1_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, + const operators::ActivationParam act_param, ARMContext* ctx); template @@ -153,20 +123,21 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, int padh, ARMContext* ctx); -void conv_depthwise_5x5s1_fp32(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, +void conv_depthwise_5x5s1_fp32(float* dout, + const float* din, const float* weights, const float* bias, - int pad, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + const operators::ConvParam& param, ARMContext* ctx); void conv_depthwise_5x5s2_fp32(const float* din, @@ -180,13 +151,46 @@ void conv_depthwise_5x5s2_fp32(const float* din, int win, const float* weights, const float* bias, - int pad, + const operators::ConvParam& param, + const operators::ActivationParam act_param, + ARMContext* ctx); + +void conv_depthwise_5x5s2p2_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +template +void conv_depthwise_5x5s1_int8(Dtype* dout, + 
const int8_t* din, + const int8_t* weights, + const float* scale, + const float* bias, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, ARMContext* ctx); template -void conv_depthwise_5x5s1_int8(Dtype* dout, +void conv_depthwise_5x5s2_int8(Dtype* dout, const int8_t* din, const int8_t* weights, const float* scale, diff --git a/lite/backends/arm/math/conv_depthwise_3x3p0.cc b/lite/backends/arm/math/conv_depthwise_3x3p0.cc deleted file mode 100644 index 0c050ffe6fb0f064f5c26ea0da6acee17f4403ae..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_3x3p0.cc +++ /dev/null @@ -1,4178 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/arm/math/conv_depthwise.h" -#include - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s1p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! 
for input width <= 4 -void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3p0_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - if (stride == 1) { - if (flag_relu) { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } else { //! stride = 2 - if (flag_relu) { - if (w_in > 8) { - conv_depthwise_3x3s2p0_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p0_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 8) { - conv_depthwise_3x3s2p0_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p0_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -// 4line -void conv_depthwise_3x3s1p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - // wr0 = vsetq_lane_f32(0.f, wr0, 3); - // wr1 = vsetq_lane_f32(0.f, wr1, 3); - // wr2 = vsetq_lane_f32(0.f, wr2, 3); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - - // mid - // "cmp %[cnt], #1 \n" - // "blt 5f \n" - "4: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla 
v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r5 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 4b \n" - - // right - "5: \n" - "cmp %[remain], #1 \n" - "blt 0f \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - 
"bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v18.4s}, [%[rmask]] \n" - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 
+= din0_2345 * - w0[2]*/ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - // end - "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_out; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - case 0: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3\n" - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! 
@ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "cmp %[remain], #1 @ check whether has " - "mid cols\n" - "blt 0f @ jump to main loop start " - "point\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - "0: \n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! 
end of processing mid rows
-    }
-#endif
-  }
-}
-
-/**
- * \brief depthwise convolution kernel 3x3, stride 2
- */
-// w_in > 7
-void conv_depthwise_3x3s2p0_bias(float* dout,
-                                 const float* din,
-                                 const float* weights,
-                                 const float* bias,
-                                 bool flag_bias,
-                                 const int num,
-                                 const int ch_in,
-                                 const int h_in,
-                                 const int w_in,
-                                 const int h_out,
-                                 const int w_out,
-                                 ARMContext* ctx) {
-  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
-  int out_pad_idx[4] = {0, 1, 2, 3};
-
-  int tile_w = w_out >> 2;
-  int cnt_remain = w_out % 4;
-
-  unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));
-
-  uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
-                                   vld1q_s32(right_pad_idx));  // 0 2 4 6
-  uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
-                                   vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
-  uint32x4_t wmask =
-      vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx));  // 0 1 2 3
-  int size_in_channel = w_in * h_in;
-  int size_out_channel = w_out * h_out;
-
-  float* zero_ptr = ctx->workspace_data<float>();
-  memset(zero_ptr, 0, w_in * sizeof(float));
-  float* write_ptr = zero_ptr + w_in;
-
-  unsigned int dmask[12];
-
-  vst1q_u32(dmask, vmask_rp1);
-  vst1q_u32(dmask + 4, vmask_rp2);
-  vst1q_u32(dmask + 8, wmask);
-
-  for (int n = 0; n < num; ++n) {
-    const float* din_batch = din + n * ch_in * size_in_channel;
-    float* dout_batch = dout + n * ch_in * size_out_channel;
-#pragma omp parallel for
-    for (int i = 0; i < ch_in; ++i) {
-      const float* din_channel = din_batch + i * size_in_channel;
-      float* dout_channel = dout_batch + i * size_out_channel;
-
-      const float* weight_ptr = weights + i * 9;
-      float32x4_t wr0 = vld1q_f32(weight_ptr);
-      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
-      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
-
-      float32x4_t vzero = vdupq_n_f32(0.f);
-
-      float32x4_t wbias;
-      float bias_c = 0.f;
-      if (flag_bias) {
-        wbias = vdupq_n_f32(bias[i]);
-        bias_c = bias[i];
-      } else {
-        wbias = vdupq_n_f32(0.f);
-      }
-
-      const float* dr0 = din_channel;
-      const float* dr1 = dr0 + w_in;
-      const float* dr2 = dr1 + w_in;
-      const float* dr3 = dr2 + w_in;
-      const float* dr4 = dr3 + w_in;
-
-      const float* din0_ptr = dr0;
-      const float* din1_ptr = dr1;
-      const float* din2_ptr = dr2;
-      const float* din3_ptr = dr3;
-      const float* din4_ptr = dr4;
-
-      float* doutr0 = dout_channel;
-      float* doutr0_ptr = nullptr;
-      float* doutr1_ptr = nullptr;
-
-#ifdef __aarch64__
-      for (int i = 0; i < h_out; i += 2) {
-        din0_ptr = dr0;
-        din1_ptr = dr1;
-        din2_ptr = dr2;
-        din3_ptr = dr3;
-        din4_ptr = dr4;
-
-        doutr0_ptr = doutr0;
-        doutr1_ptr = doutr0 + w_out;
-
-        dr0 = dr4;
-        dr1 = dr0 + w_in;
-        dr2 = dr1 + w_in;
-        dr3 = dr2 + w_in;
-        dr4 = dr3 + w_in;
-
-        //! process bottom pad
-        if (i + 4 >= h_in) {
-          switch (i + 4 - h_in) {
-            case 4:
-              din1_ptr = zero_ptr;
-            case 3:
-              din2_ptr = zero_ptr;
-            case 2:
-              din3_ptr = zero_ptr;
-            case 1:
-              din4_ptr = zero_ptr;
-            case 0:
-              din4_ptr = zero_ptr;
-            default:
-              break;
-          }
-        }
-        //! process output pad
-        if (i + 2 > h_out) {
-          doutr1_ptr = write_ptr;
-        }
-        int cnt = tile_w;
-        asm volatile(
-            // top
-            // Load up 12 elements (3 vectors) from each of 8 sources.
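For reference before the scheduling-heavy loop body below: the `ld2` instructions de-interleave eight consecutive input floats into an even-column vector {0,2,4,6} and an odd-column vector {1,3,5,7}, and the trailing `ld1`/`ext` pair builds the shifted even vector {2,4,6,8}. Each filter row therefore contributes exactly one multiply-accumulate per tap to four stride-2 outputs. A minimal intrinsics sketch of the same access pattern (the helper name, the scalar-weight signature, and the requirement that `din[0..8]` be readable are our assumptions, not part of this patch):

```cpp
#include <arm_neon.h>

// Sketch (names ours): one input row's contribution to four stride-2
// outputs, mirroring the deleted ld2/ld1/ext/fmla pattern. Assumes
// din[0..8] are readable; the kernel's bif masking covers the right
// edge where that does not hold.
static inline float32x4_t row3_s2(const float* din,
                                  float w0, float w1, float w2,
                                  float32x4_t acc) {
  float32x4x2_t v = vld2q_f32(din);      // val[0]={x0,x2,x4,x6}, val[1]={x1,x3,x5,x7}
  float32x4_t nxt = vld1q_f32(din + 8);  // {x8,...}; only lane 0 is consumed
  float32x4_t sh = vextq_f32(v.val[0], nxt, 1);  // {x2,x4,x6,x8}
  acc = vmlaq_n_f32(acc, v.val[0], w0);  // even columns  * w[0]
  acc = vmlaq_n_f32(acc, v.val[1], w1);  // odd columns   * w[1]
  acc = vmlaq_n_f32(acc, sh, w2);        // shifted evens * w[2]
  return acc;
}
```

A full output row is then the sum of three such row contributions (five input rows feed the two output rows computed per iteration here), accumulated into a register pre-filled with the bias, which is what the `and v16.16b, %[vbias].16b, %[vbias].16b` initialisation arranges.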
- "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // 
pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v13.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - for (int i = 0; i < h_out; i++) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = tile_w; - unsigned int* mask_ptr = dmask; - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" - "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - // mid - "2: \n" - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "subs %[cnt], #1 \n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! 
\n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} - -// 4line -void conv_depthwise_3x3s1p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - // wr0 = vsetq_lane_f32(0.f, wr0, 3); - // wr1 = vsetq_lane_f32(0.f, wr1, 3); - // wr2 = vsetq_lane_f32(0.f, wr2, 3); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! 
process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - - // mid - "4: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, 
%[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - // r5 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 4b \n" - - // right - "5: \n" - "cmp %[remain], #1 \n" - "blt 0f \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 
\n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v18.4s}, [%[rmask]] \n" - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , 
v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - // end - "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_out; i += 2) { - //! process top pad pad_h = 1 - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - case 0: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d29}, [%[din3_ptr]]! 
@ load din r3\n" - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "cmp %[remain], #1 @ check whether has " - "mid cols\n" - "blt 0f @ jump to main loop start " - "point\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - "0: \n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! 
end of processing mid rows
-    }
-#endif
-  }
-}
-/**
- * \brief depthwise convolution kernel 3x3, stride 2, with relu
- */
-// w_in > 7
-void conv_depthwise_3x3s2p0_bias_relu(float* dout,
-                                      const float* din,
-                                      const float* weights,
-                                      const float* bias,
-                                      bool flag_bias,
-                                      const int num,
-                                      const int ch_in,
-                                      const int h_in,
-                                      const int w_in,
-                                      const int h_out,
-                                      const int w_out,
-                                      ARMContext* ctx) {
-  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
-  int out_pad_idx[4] = {0, 1, 2, 3};
-
-  int tile_w = w_out >> 2;
-  int cnt_remain = w_out % 4;
-
-  unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));
-
-  uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
-                                   vld1q_s32(right_pad_idx));  // 0 2 4 6
-  uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
-                                   vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
-  uint32x4_t wmask =
-      vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx));  // 0 1 2 3
-  int size_in_channel = w_in * h_in;
-  int size_out_channel = w_out * h_out;
-
-  float* zero_ptr = ctx->workspace_data<float>();
-  memset(zero_ptr, 0, w_in * sizeof(float));
-  float* write_ptr = zero_ptr + w_in;
-
-  unsigned int dmask[12];
-
-  vst1q_u32(dmask, vmask_rp1);
-  vst1q_u32(dmask + 4, vmask_rp2);
-  vst1q_u32(dmask + 8, wmask);
-
-  for (int n = 0; n < num; ++n) {
-    const float* din_batch = din + n * ch_in * size_in_channel;
-    float* dout_batch = dout + n * ch_in * size_out_channel;
-#pragma omp parallel for
-    for (int i = 0; i < ch_in; ++i) {
-      const float* din_channel = din_batch + i * size_in_channel;
-      float* dout_channel = dout_batch + i * size_out_channel;
-
-      const float* weight_ptr = weights + i * 9;
-      float32x4_t wr0 = vld1q_f32(weight_ptr);
-      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
-      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
-
-      float32x4_t vzero = vdupq_n_f32(0.f);
-
-      float32x4_t wbias;
-      float bias_c = 0.f;
-      if (flag_bias) {
-        wbias = vdupq_n_f32(bias[i]);
-        bias_c = bias[i];
-      } else {
-        wbias = vdupq_n_f32(0.f);
-      }
-
-      const float* dr0 = din_channel;
-      const float* dr1 = dr0 + w_in;
-      const float* dr2 = dr1 + w_in;
-      const float* dr3 = dr2 + w_in;
-      const float* dr4 = dr3 + w_in;
-
-      const float* din0_ptr = dr0;
-      const float* din1_ptr = dr1;
-      const float* din2_ptr = dr2;
-      const float* din3_ptr = dr3;
-      const float* din4_ptr = dr4;
-
-      float* doutr0 = dout_channel;
-      float* doutr0_ptr = nullptr;
-      float* doutr1_ptr = nullptr;
-
-#ifdef __aarch64__
-      for (int i = 0; i < h_out; i += 2) {
-        din0_ptr = dr0;
-        din1_ptr = dr1;
-        din2_ptr = dr2;
-        din3_ptr = dr3;
-        din4_ptr = dr4;
-
-        doutr0_ptr = doutr0;
-        doutr1_ptr = doutr0 + w_out;
-
-        dr0 = dr4;
-        dr1 = dr0 + w_in;
-        dr2 = dr1 + w_in;
-        dr3 = dr2 + w_in;
-        dr4 = dr3 + w_in;
-
-        //! process bottom pad
-        if (i + 4 >= h_in) {
-          switch (i + 4 - h_in) {
-            case 4:
-              din1_ptr = zero_ptr;
-            case 3:
-              din2_ptr = zero_ptr;
-            case 2:
-              din3_ptr = zero_ptr;
-            case 1:
-              din4_ptr = zero_ptr;
-            case 0:
-              din4_ptr = zero_ptr;
-            default:
-              break;
-          }
-        }
-        //! process output pad
-        if (i + 2 > h_out) {
-          doutr1_ptr = write_ptr;
-        }
-        int cnt = tile_w;
-        asm volatile(
-            // top
-            // Load up 12 elements (3 vectors) from each of 8 sources.
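This `_relu` variant repeats the stride-2 main body of `conv_depthwise_3x3s2p0_bias` above almost line for line; the functional difference is a fused activation: each accumulated output vector passes through `fmax vN.4s, vN.4s, %[vzero].4s` before its `st1`, so ReLU costs one instruction instead of a second pass over the output tensor. In the right-pad tail the clamped result is additionally blended with the previously stored values under `%[wmask]`, so only the valid remainder lanes are overwritten. A sketch of that epilogue in intrinsics (helper and variable names are ours; the intended register mapping is noted in the comments):

```cpp
#include <arm_neon.h>

// Sketch of the deleted right-edge store path: partial sums are folded
// into the bias-initialised accumulator, clamped at zero (ReLU), then
// blended with the previous output under the remainder mask.
static inline void store_relu_masked(float* out,
                                     float32x4_t acc,        // v16, starts as bias
                                     float32x4_t partial_a,  // v11
                                     float32x4_t partial_b,  // v12
                                     uint32x4_t wmask) {     // %[wmask]
  acc = vaddq_f32(acc, partial_a);          // fadd v16.4s, v16.4s, v11.4s
  acc = vaddq_f32(acc, partial_b);          // fadd v16.4s, v16.4s, v12.4s
  acc = vmaxq_f32(acc, vdupq_n_f32(0.f));   // fmax v16.4s, v16.4s, vzero
  float32x4_t prev = vld1q_f32(out);        // ld1 {v0.4s}, [%[outptr0]]
  acc = vbslq_f32(wmask, acc, prev);        // bif v16.16b, v0.16b, wmask
  vst1q_f32(out, acc);                      // st1 {v16.4s}, [%[outptr0]]
}
```

In the main loop the same sequence runs without the load and blend, which is why only the `1:`/`3:` right-pad path below reloads `%[outptr0]` and `%[outptr1]`.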
- "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext 
v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v14.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - for (int i = 0; i < h_out; i++) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = tile_w; - unsigned int* mask_ptr = dmask; - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" - "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - // mid - "2: \n" - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "subs %[cnt], #1 \n" - "vmax.f32 q3, q3, q9 @ relu \n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! 
\n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float* dr0 = din_channel + j * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - doutr1 = trash_buf; - default: - break; - } - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" - - "bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234 - - "bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234 - "bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234 - - "bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234 - "bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234 - - "bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234 - "bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234 - - "ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345 - "ext v9.16b, v0.16b, v1.16b, 
#8\n" // d1_3450 - - "and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias - "and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias - - // r0 - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345 - "ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450 - - // r1 - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345 - "ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450 - - // r2 - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345 - "ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450 - - // r3 - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fadd v12.4s, v12.4s, v10.4s\n" - - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "fadd v12.4s, v12.4s, v11.4s\n" // out1 - "fadd v13.4s, v13.4s, v14.4s\n" // out2 - "fadd v13.4s, v13.4s, v15.4s\n" // out2 - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - unsigned int* vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - "vadd.f32 q4, q4, q11 @ q4 += q10 \n" - - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" - "vadd.f32 q5, q5, q9 @ q4 += q10 \n" - - "vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 - */ - -void conv_depthwise_3x3s2p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 
2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - float out_buf[4]; - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - for (int j = 0; j < h_out; ++j) { - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - "and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - - "fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00 - "fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01 - "fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02 - - "fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12 - "fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21 - "fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v16.4s \n" - - // "fadd v4.4s, v4.4s, %[bias].4s \n" - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! 
@ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0} - "vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // {0,2,4,6} - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // {1,3,5,7} - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // {2,4,6,0} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(dmask) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float* dr0 = din_channel + j * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - doutr1 = trash_buf; - default: - break; - } - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" - - "bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234 - - "bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234 - "bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234 - - "bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234 - "bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234 - - "bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234 - "bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234 - - "ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345 - "ext v9.16b, v0.16b, v1.16b, #8\n" // d1_3450 - - "and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias - "and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias - - // r0 - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345 - "ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450 - - // r1 - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345 - "ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450 - - // r2 - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" // 
d0_1234 * w0[0] - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345 - "ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450 - - // r3 - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fadd v12.4s, v12.4s, v10.4s\n" - - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "fadd v12.4s, v12.4s, v11.4s\n" // out1 - "fadd v13.4s, v13.4s, v14.4s\n" // out2 - "fadd v13.4s, v13.4s, v15.4s\n" // out2 - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - "fmax v12.4s, v12.4s, %[zero].4s \n" - "fmax v13.4s, v13.4s, %[zero].4s \n" - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - unsigned int* vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - "vadd.f32 q4, q4, q11 @ q4 += q10 \n" - - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" - "vadd.f32 q5, q5, q9 @ q4 += q10 \n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - // doutr0 = doutr1; - // doutr1 += w_out; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 7 - */ -void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int 
ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - float out_buf[4]; - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - for (int j = 0; j < h_out; ++j) { - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]] \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - "and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - - "fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00 - "fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01 - "fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02 - - "fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12 - "fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21 - "fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v16.4s \n" - "fmax v4.4s, v4.4s, v9.4s \n" - - // "fadd v4.4s, v4.4s, %[bias].4s \n" - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf), - [mask_ptr] "r"(mask_ptr) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. 
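[reviewer note, not part of the diff] The small-width stride-2 kernels above (and the armv7 path that follows) all lean on the same load pattern: `ld2`/`vld2.32` deinterleaves the input row into even columns {0,2,4,6} and odd columns {1,3,5,7}, so each 3-tap filter row becomes three vector multiply-accumulates over {0,2,4,6}, {1,3,5,7} and {2,4,6,0}. The last vector is produced by `ext` against a zero register, which is safe here because the right-pad mask has already zeroed the out-of-range lanes. A minimal intrinsics sketch of one such row (the helper name row3_stride2 is illustrative, not from the tree):

    #include <arm_neon.h>

    // out[j] = in[2j]*w0 + in[2j+1]*w1 + in[2j+2]*w2 for j = 0..3,
    // assuming lanes at or beyond w_in were already masked to zero.
    static inline float32x4_t row3_stride2(float32x4_t even,  // {in0,in2,in4,in6}
                                           float32x4_t odd,   // {in1,in3,in5,in7}
                                           float w0, float w1, float w2) {
      // Shift the even lanes left by one float against zero: {in2,in4,in6,0}.
      float32x4_t even_next = vextq_f32(even, vdupq_n_f32(0.f), 1);
      float32x4_t acc = vmulq_n_f32(even, w0);
      acc = vmlaq_n_f32(acc, odd, w1);
      return vmlaq_n_f32(acc, even_next, w2);
    }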
- "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0} - "vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // {0,2,4,6} - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // {1,3,5,7} - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // {2,4,6,0} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(mask_ptr) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_3x3p1.cc b/lite/backends/arm/math/conv_depthwise_3x3p1.cc deleted file mode 100644 index 6f28d48d6d2bdd60e0c33f9b4b753835337fc8a4..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_3x3p1.cc +++ /dev/null @@ -1,4850 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
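[reviewer note, not part of the diff] Both the width<=4 p0 kernels above and the p1 kernels in the file deleted below share the same right-edge handling: a per-lane integer compare between a column/pad index vector and the width builds a bit mask, and `bif`/`vbif` (bitwise insert-if-false) keeps in-range lanes while zeroing the rest, so the inner loops never branch on width. A minimal sketch with intrinsics, using the stride-2 form of the compare (mask_lanes is an illustrative name, not from the tree):

    #include <arm_neon.h>

    // Zero the lanes whose source column falls at or beyond w_in.
    // idx holds the input column each lane reads, e.g. {0,2,4,6} for
    // the even half of a stride-2 load.
    static inline float32x4_t mask_lanes(float32x4_t v,
                                         const int32_t idx[4],
                                         int32_t w_in) {
      uint32x4_t keep = vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(idx));  // idx < w_in
      return vbslq_f32(keep, v, vdupq_n_f32(0.f));  // keep in-range lanes, else 0
    }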
- -#include "lite/backends/arm/math/conv_depthwise.h" -#include - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s1p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3p1_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - if (stride == 1) { - if (flag_relu) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } else { //! 
stride = 2 - if (flag_relu) { - if (w_in > 7) { - conv_depthwise_3x3s2p1_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 7) { - conv_depthwise_3x3s2p1_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -// 4line -void conv_depthwise_3x3s1p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - // printf("conv3x3_dw start \n"); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 3) >> 2; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_in; i += 4) { - //! 
process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - - // left - // r0 - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * - w0[1]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * - w0[0]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * - w0[2]*/ - - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * - w1[1]*/ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ - - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla 
v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ - - // r5 - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" 
/*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ - "cmp %[cnt], #1 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "blt 3f \n" - // mid - "1: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += 
din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 1b \n" - - // right - "3: \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v18.4s}, [%[rmask]] \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, 
%[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? 
bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" - "vext.32 q7, q8, q9, #1 @ 1234\n" - - // left - // r0 - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" - - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" - "vext.32 q7, q10, q11, #1 @ 1234\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" - "vext.32 q7, q12, q13, #1 @ 1234\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! 
@ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" - "vext.32 q7, q14, q15, #1 @ 1234\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "cmp %[cnt], #1 @ check whether has " - "mid cols\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - "blt 3f @ jump to main loop start " - "point\n" - - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "vld1.32 {d19}, [%[vmask]]! 
@ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! 
end of processing mid rows - } -#endif - } -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2 - */ -// w_in > 7 -void conv_depthwise_3x3s2p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - int size_pad_bottom = h_out * 2 - h_in; - - int cnt_col = (w_out >> 2) - 2; - int size_right_remain = w_in - (7 + cnt_col * 8); - if (size_right_remain >= 9) { - cnt_col++; - size_right_remain -= 8; - } - int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // - - int size_right_pad = w_out * 2 - w_in; - - uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - uint32x4_t wmask = - vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - float* zero_ptr = ctx->workspace_data<float>(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int dmask[12]; - - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - vst1q_u32(dmask + 8, wmask); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t vzero = vdupq_n_f32(0.f); - - float32x4_t wbias; - float bias_c = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_c = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - const float* din3_ptr = dr3; - const float* din4_ptr = dr4; - - float* doutr0 = dout_channel; - float* doutr0_ptr = nullptr; - float* doutr1_ptr = nullptr; - -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - din4_ptr = dr4; - - doutr0_ptr = doutr0; - doutr1_ptr = doutr0 + w_out; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - din4_ptr = dr3; - dr0 = dr3; - dr1 = dr4; - } else { - dr0 = dr4; - dr1 = dr0 + w_in; - } - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - - //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { - case 4: - din1_ptr = zero_ptr; - case 3: - din2_ptr = zero_ptr; - case 2: - din3_ptr = zero_ptr; - case 1: - din4_ptr = zero_ptr; - default: - break; - } - } - //! process output pad - if (i / 2 + 2 > h_out) { - doutr1_ptr = write_ptr; - } - int cnt = cnt_col; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources.
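// The stride-2 loop below works on de-interleaved columns: "ld2" splits each
// input row into even lanes {0,2,4,6} and odd lanes {1,3,5,7}, and "ext"
// against vzero prepends the implicit left zero pad to form {0,1,3,5}, so
// that output column j consumes input columns 2j-1, 2j and 2j+1. A minimal
// NEON-intrinsics sketch of the row-0 accumulation (variable names are
// illustrative, not from this file):
//   float32x4x2_t in = vld2q_f32(din0_ptr);            // {0,2,4,6}, {1,3,5,7}
//   float32x4_t pad = vextq_f32(vzero, in.val[1], 3);  // {0,1,3,5}
//   float32x4_t acc = vmlaq_laneq_f32(wbias, in.val[0], wr0, 1);
//   acc = vmlaq_laneq_f32(acc, in.val[1], wr0, 2);
//   acc = vmlaq_laneq_f32(acc, pad, wr0, 0);           // 4 outputs of row 0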
- "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5} - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr0], %[inptr0], #4 \n" - "sub %[inptr1], %[inptr1], #4 \n" - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr2], %[inptr2], #4 \n" - "sub %[inptr3], %[inptr3], #4 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - - "fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - - "fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - "fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr4], %[inptr4], #4 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5} - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "cmp %[cnt], #1 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "blt 1f \n" - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" 
// {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, 
v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v13.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - for (int i = 0; i < h_in; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - } - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = cnt_col; - unsigned int* mask_ptr = dmask; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11 - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13 - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15 - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vext.32 q6, q9, q11, #3 @ shift right 1 " - "data\n" // q2 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "sub %[din0_ptr], #4 @ inpitr0 - 1\n" - "sub %[din1_ptr], #4 @ inpitr1 - 1\n" - "sub %[din2_ptr], #4 @ inpitr2 - 1\n" - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "vld2.32 {d24-d27}, [%[din1_ptr]]! 
@ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, " - "out1\n" // q0 * w01 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, " - "out1\n" // q1 * w02 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, " - "out1\n" // q2 * w00 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "cmp %[cnt], #1 \n" - "blt 1f \n" - // mid - "2: \n" - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "subs %[cnt], #1 \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! 
@ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} - -// 4line -void conv_depthwise_3x3s1p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - // printf("conv3x3_dw start \n"); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 3) >> 2; - int tile_h = (h_in + 3) >> 2; - int cnt_col = tile_w - 2; - float* zero_ptr = ctx->workspace_data<float>(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); - int size_pad_bottom = (unsigned int)(1 + (tile_h << 2) - h_in); - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_in; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //!
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - - // left - // r0 - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * - w0[1]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * - w0[0]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * - w0[2]*/ - - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * - w1[1]*/ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ - - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* 
outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - // r5 - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ - - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ - - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ - "cmp %[cnt], #1 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "blt 3f \n" - // mid - "1: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - 
w0[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" 
/* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "subs %[cnt], %[cnt], #1 \n" - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 1b \n" - - // right - "3: \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v18.4s}, [%[rmask]] \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext 
v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? 
bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast<float*>(zero); - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" - "vext.32 q7, q8, q9, #1 @ 1234\n" - - // left - // r0 - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" - - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" - "vext.32 q7, q10, q11, #1 @ 1234\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" - "vext.32 q7, q12, q13, #1 @ 1234\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]!
@ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" - "vext.32 q7, q14, q15, #1 @ 1234\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "cmp %[cnt], #1 @ check whether has " - "mid cols\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - "blt 3f @ jump to main loop start " - "point\n" - - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add " - "pointer\n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! end of processing mid rows - } -#endif - } -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, with relu - */ -// w_in > 7 -void conv_depthwise_3x3s2p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - int size_pad_bottom = h_out * 2 - h_in; - - int cnt_col = (w_out >> 2) - 2; - int size_right_remain = w_in - (7 + cnt_col * 8); - if (size_right_remain >= 9) { - cnt_col++; - size_right_remain -= 8; - } - int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // - - int size_right_pad = w_out * 2 - w_in; - - uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - uint32x4_t wmask = - vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - float* zero_ptr = ctx->workspace_data<float>(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int dmask[12]; - - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - vst1q_u32(dmask + 8, wmask); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t vzero = vdupq_n_f32(0.f); - - float32x4_t wbias; - float bias_c = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_c = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - const float* din3_ptr = dr3; - const float* din4_ptr = dr4; - - float* doutr0 = dout_channel; - float* doutr0_ptr = nullptr; - float* doutr1_ptr = nullptr; - -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - din4_ptr = dr4; - - doutr0_ptr = doutr0; - doutr1_ptr = doutr0 + w_out; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - din4_ptr = dr3; - dr0 = dr3; - dr1 = dr4; - } else { - dr0 = dr4; - dr1 = dr0 + w_in; - } - dr2 = dr1 +
w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - - //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { - case 4: - din1_ptr = zero_ptr; - case 3: - din2_ptr = zero_ptr; - case 2: - din3_ptr = zero_ptr; - case 1: - din4_ptr = zero_ptr; - default: - break; - } - } - //! process output pad - if (i / 2 + 2 > h_out) { - doutr1_ptr = write_ptr; - } - int cnt = cnt_col; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5} - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr0], %[inptr0], #4 \n" - "sub %[inptr1], %[inptr1], #4 \n" - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr2], %[inptr2], #4 \n" - "sub %[inptr3], %[inptr3], #4 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - - "fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - - "fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - "fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr4], %[inptr4], #4 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5} - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "cmp %[cnt], #1 \n" - - 
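// Note how this relu variant re-seeds the accumulators from the bias with
// "and vN.16b, %[vbias].16b, %[vbias].16b" (a plain register copy) and fuses
// the activation as "fmax vN.4s, vN.4s, vzero.4s" right before each store.
// A hedged scalar sketch of the per-lane epilogue (sum_r0..sum_r2 stand for
// the partial sums over input rows 0..2 and are not names from this file):
//   float acc = bias_c + sum_r0 + sum_r1 + sum_r2;  // rows 0..2 -> out row 0
//   acc = (acc > 0.f) ? acc : 0.f;                  // relu fused before store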
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "blt 1f \n" - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "subs %[cnt], %[cnt], #1 \n" - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext 
v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v14.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - - for (int i = 0; i < h_in; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - } - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = cnt_col; - - unsigned int* mask_ptr = dmask; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11 - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13 - "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15 - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vext.32 q6, q9, q11, #3 @ shift right 1 " - "data\n" // q2 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "sub %[din0_ptr], #4 @ inpitr0 - 1\n" - "sub %[din1_ptr], #4 @ inpitr1 - 1\n" - "sub %[din2_ptr], #4 @ inpitr2 - 1\n" - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, " - "out1\n" // q0 * w01 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, " - "out1\n" // q1 * w02 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, " - "out1\n" // q2 * w00 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "cmp %[cnt], #1 \n" - "blt 1f \n" - // mid - "2: \n" - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "subs %[cnt], #1 \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! 
@ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s}, [%[din0]], #16\n" - "ld1 {v1.4s}, [%[din1]], #16\n" - "ld1 {v2.4s}, [%[din2]], #16\n" - "ld1 {v3.4s}, [%[din3]], #16\n" - - "bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234 - "bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234 - "bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234 - - "ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123 - "ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123 - "ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123 - "ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123 - - "ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340 - "ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340 - "ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340 - "ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340 - - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" - - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" - - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" - - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" - - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" - - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" - - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" - - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" - - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" - - "fadd v12.4s, v12.4s, v14.4s\n" - "fadd v12.4s, v12.4s, v16.4s\n" - - "fadd v13.4s, v13.4s, v15.4s\n" // out1 - "fadd v13.4s, v13.4s, v17.4s\n" // out2 - - "fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias - "fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, 
[%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d12-d13}, [%[din0]]!\n" - "vld1.32 {d14-d15}, [%[din1]]!\n" - "vld1.32 {d16-d17}, [%[din2]]!\n" - "vld1.32 {d18-d19}, [%[din3]]!\n" - - "vbif q6, %q[zero], %q[mask]\n" // d0_1234 - "vbif q7, %q[zero], %q[mask]\n" // d1_1234 - "vbif q8, %q[zero], %q[mask]\n" // d2_1234 - "vbif q9, %q[zero], %q[mask]\n" // d3_1234 - - "vmul.f32 q14, q6, %e[wr0][1]\n" - "vmul.f32 q15, q7, %e[wr0][1]\n" - - "vmla.f32 q14, q7, %e[wr1][1]\n" - "vmla.f32 q15, q8, %e[wr1][1]\n" - - "vmla.f32 q14, q8, %e[wr2][1]\n" - "vmla.f32 q15, q9, %e[wr2][1]\n" - - "vext.32 q10, %q[zero], q6, #3\n" // d0_0123 - "vext.32 q11, %q[zero], q7, #3\n" // d1_0123 - "vext.32 q12, %q[zero], q8, #3\n" // d2_0123 - "vext.32 q13, %q[zero], q9, #3\n" // d3_0123 - - "vmla.f32 q14, q10, %e[wr0][0]\n" - "vmla.f32 q15, q11, %e[wr0][0]\n" - - "vmla.f32 q14, q11, %e[wr1][0]\n" - "vmla.f32 q15, q12, %e[wr1][0]\n" - - "vmla.f32 q14, q12, %e[wr2][0]\n" - "vmla.f32 q15, q13, %e[wr2][0]\n" - - "vext.32 q10, q6, %q[zero], #1\n" // d0_2340 - "vext.32 q11, q7, %q[zero], #1\n" // d1_2340 - "vext.32 q12, q8, %q[zero], #1\n" // d2_2340 - "vext.32 q13, q9, %q[zero], #1\n" // d3_2340 - - "vmla.f32 q14, q10, %f[wr0][0]\n" - "vmla.f32 q15, q11, %f[wr0][0]\n" - - "vmla.f32 q14, q11, %f[wr1][0]\n" - "vmla.f32 q15, q12, %f[wr1][0]\n" - - "vmla.f32 q14, q12, %f[wr2][0]\n" // out1 - "vmla.f32 q15, q13, %f[wr2][0]\n" // out2 - - "vadd.f32 q14, q14, %q[bias]\n" // out1 add bias - "vadd.f32 q15, q15, %q[bias]\n" // out2 add bias - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vst1.32 {d28-d29}, [%[out1]]\n" - "vst1.32 {d30-d31}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 - */ - -void conv_depthwise_3x3s2p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - 
vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - int hs = -1; - int he = 2; - float out_buf[4]; - for (int j = 0; j < h_out; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - if (hs == -1) { - dr0 = zeros; - } - if (he > h_in) { - dr2 = zeros; - } - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 = - // {0,1,3,5} - "ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 = - // {0,1,3,5} - "ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 = - // {0,1,3,5} - - "fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01 - "fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02 - "fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00 - - "fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12 - "fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21 - "fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v6.4s \n" - - "fadd v4.4s, v4.4s, %[bias].4s \n" - - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, " - "out0\n" // q10 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, " - "out0\n" // q11 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w00 - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - hs += 2; - he += 2; - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s}, [%[din0]], #16\n" - "ld1 {v1.4s}, [%[din1]], #16\n" - "ld1 {v2.4s}, [%[din2]], #16\n" - "ld1 {v3.4s}, [%[din3]], #16\n" - - "bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234 - "bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234 - "bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234 - - "ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123 - "ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123 - "ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123 - "ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123 - - "ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340 - "ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340 - "ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340 - "ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340 - - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" - - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" - - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" - - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" - - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" - - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" - - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" - - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" - - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" - - "fadd v12.4s, v12.4s, v14.4s\n" - "fadd v12.4s, v12.4s, v16.4s\n" - - "fadd v13.4s, v13.4s, v15.4s\n" // out1 - "fadd v13.4s, v13.4s, v17.4s\n" // out2 - - "fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias - "fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - - "fmax v12.4s, v12.4s, %[zero].4s\n" // out1 -> relu - 
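// Illustrative note: the fadd/fmax sequence here is the bias + relu
// epilogue used throughout these kernels. As intrinsics (vsum, vbias,
// vzero and out are hypothetical names):
//   float32x4_t vout = vmaxq_f32(vaddq_f32(vsum, vbias), vzero);
//   vst1q_f32(out, vout);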
"fmax v13.4s, v13.4s, %[zero].4s\n" // out2 -> relu - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d12-d13}, [%[din0]]!\n" - "vld1.32 {d14-d15}, [%[din1]]!\n" - "vld1.32 {d16-d17}, [%[din2]]!\n" - "vld1.32 {d18-d19}, [%[din3]]!\n" - - "vbif q6, %q[zero], %q[mask]\n" // d0_1234 - "vbif q7, %q[zero], %q[mask]\n" // d1_1234 - "vbif q8, %q[zero], %q[mask]\n" // d2_1234 - "vbif q9, %q[zero], %q[mask]\n" // d3_1234 - - "vmul.f32 q14, q6, %e[wr0][1]\n" - "vmul.f32 q15, q7, %e[wr0][1]\n" - - "vmla.f32 q14, q7, %e[wr1][1]\n" - "vmla.f32 q15, q8, %e[wr1][1]\n" - - "vmla.f32 q14, q8, %e[wr2][1]\n" - "vmla.f32 q15, q9, %e[wr2][1]\n" - - "vext.32 q10, %q[zero], q6, #3\n" // d0_0123 - "vext.32 q11, %q[zero], q7, #3\n" // d1_0123 - "vext.32 q12, %q[zero], q8, #3\n" // d2_0123 - "vext.32 q13, %q[zero], q9, #3\n" // d3_0123 - - "vmla.f32 q14, q10, %e[wr0][0]\n" - "vmla.f32 q15, q11, %e[wr0][0]\n" - - "vmla.f32 q14, q11, %e[wr1][0]\n" - "vmla.f32 q15, q12, %e[wr1][0]\n" - - "vmla.f32 q14, q12, %e[wr2][0]\n" - "vmla.f32 q15, q13, %e[wr2][0]\n" - - "vext.32 q10, q6, %q[zero], #1\n" // d0_2340 - "vext.32 q11, q7, %q[zero], #1\n" // d1_2340 - "vext.32 q12, q8, %q[zero], #1\n" // d2_2340 - "vext.32 q13, q9, %q[zero], #1\n" // d3_2340 - - "vmla.f32 q14, q10, %f[wr0][0]\n" - "vmla.f32 q15, q11, %f[wr0][0]\n" - - "vmla.f32 q14, q11, %f[wr1][0]\n" - "vmla.f32 q15, q12, %f[wr1][0]\n" - - "vmla.f32 q14, q12, %f[wr2][0]\n" // out1 - "vmla.f32 q15, q13, %f[wr2][0]\n" // out2 - - "vadd.f32 q14, q14, %q[bias]\n" // out1 add bias - "vadd.f32 q15, q15, %q[bias]\n" // out2 add bias - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vmax.f32 q14, q14, %q[zero]\n" // out1 -> relu - "vmax.f32 q15, q15, %q[zero]\n" // out2 -> relu - - "vst1.32 {d28-d29}, [%[out1]]\n" - "vst1.32 {d30-d31}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 7 - */ -void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - 
uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - int hs = -1; - int he = 2; - float out_buf[4]; - for (int j = 0; j < h_out; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - if (hs == -1) { - dr0 = zeros; - } - if (he > h_in) { - dr2 = zeros; - } - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 = - // {0,1,3,5} - "ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 = - // {0,1,3,5} - "ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 = - // {0,1,3,5} - - "fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01 - "fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02 - "fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00 - - "fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12 - "fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21 - "fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v6.4s \n" - - "fadd v4.4s, v4.4s, %[bias].4s \n" // out add bias - "fmax v4.4s, v4.4s, v9.4s \n" - - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, " - "out0\n" // q10 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, " - "out0\n" // q11 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w00 - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu\n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - hs += 2; - he += 2; - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_3x3s1.cc b/lite/backends/arm/math/conv_depthwise_3x3s1.cc deleted file mode 100644 index 8d0ebb58ad1b7e325bae3649b13914641021038f..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_3x3s1.cc +++ /dev/null @@ -1,2539 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/backends/arm/math/conv_depthwise.h" -#include - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p0_bias(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p0_bias_s(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p1_bias(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p1_bias_s(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1_fp32(const float *din, - float *dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float *weights, - const float *bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext *ctx) { - if (pad == 0) { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias(dout, - din, - weights, - bias, - flag_bias, - flag_relu, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s(dout, - din, - weights, - bias, - flag_bias, - flag_relu, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - if (pad == 1) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias(dout, - din, - weights, - bias, - flag_bias, - flag_relu, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s(dout, - din, - weights, - bias, - flag_bias, - flag_relu, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } -} - -#ifdef __aarch64__ -#define INIT_S1 \ - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" \ - "movi v21.4s, #0x0\n" /* out0 = 0 */ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - -#define LEFT_COMPUTE_S1 \ - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ /* r0 */ \ - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * w0[1]*/ \ - \ - "ld1 
{v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * w0[0]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * w0[2]*/ \ - \ - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * w1[1]*/ \ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16=1234 */ \ - "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ \ - \ - /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ - 
"ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ - -#define LEFT_RESULT_S1 \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ /* r5 */ \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "blt 3f \n" - -#define MID_COMPUTE_S1 \ - "1: \n" /* r0 */ \ - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, 
%[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - -#define MID_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define RIGHT_COMPUTE_S1 \ - "3: \n" \ - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ - "ld1 {v22.4s}, [%[doutr0]] \n" \ - "ld1 {v23.4s}, [%[doutr1]] \n" \ - "ld1 {v24.4s}, [%[doutr2]] \n" \ - "ld1 {v25.4s}, [%[doutr3]] \n" \ - \ - "bif v0.16b, %[vzero].16b, v18.16b \n" \ - "bif v1.16b, %[vzero].16b, v19.16b \n" \ - "bif v2.16b, %[vzero].16b, v18.16b \n" \ - "bif v3.16b, %[vzero].16b, v19.16b \n" \ - \ - "bif v4.16b, %[vzero].16b, v18.16b \n" \ - "bif 
v5.16b, %[vzero].16b, v19.16b \n" \ - "bif v6.16b, %[vzero].16b, v18.16b \n" \ - "bif v7.16b, %[vzero].16b, v19.16b \n" \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v8.16b, %[vzero].16b, v18.16b \n" \ - "bif v9.16b, %[vzero].16b, v19.16b \n" \ - "bif v10.16b, %[vzero].16b, v18.16b \n" \ - "bif v11.16b, %[vzero].16b, v19.16b \n" \ - \ - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v18.4s}, [%[rmask]] \n" \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - -#define RIGHT_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v12.16b, v22.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 
+= din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define LEFT_RESULT_S1_RELU \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "blt 3f \n" - -#define MID_RESULT_S1_RELU \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += 
din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define RIGHT_RESULT_S1_RELU \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v12.16b, v22.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ 
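// Illustrative note on the bif/v18 pairs in this right-edge macro: they
// blend the freshly computed lanes with the previously stored output so a
// full 4-wide store stays safe past the end of the row. As intrinsics
// (rmask, vnew, doutr are hypothetical names):
//   float32x4_t vold = vld1q_f32(doutr);              // existing output
//   float32x4_t vres = vbslq_f32(rmask, vnew, vold);  // keep masked lanes
//   vst1q_f32(doutr, vres);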
\ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define COMPUTE_S_S1 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s}, [%[din0]], #16\n" \ - "ld1 {v1.4s}, [%[din1]], #16\n" \ - "ld1 {v2.4s}, [%[din2]], #16\n" \ - "ld1 {v3.4s}, [%[din3]], #16\n" \ - \ - "bif v0.16b, %[zero].16b, %[mask].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask].16b\n" \ - "bif v2.16b, %[zero].16b, %[mask].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask].16b\n" \ - \ - "ext v4.16b, %[zero].16b, v0.16b, #12\n" \ - "ext v5.16b, %[zero].16b, v1.16b, #12\n" \ - "ext v6.16b, %[zero].16b, v2.16b, #12\n" \ - "ext v7.16b, %[zero].16b, v3.16b, #12\n" \ - \ - "ext v8.16b, v0.16b, %[zero].16b, #4\n" \ - "ext v9.16b, v1.16b, %[zero].16b, #4\n" \ - "ext v10.16b, v2.16b, %[zero].16b, #4\n" \ - "ext v11.16b, v3.16b, %[zero].16b, #4\n" \ - \ - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ - \ - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ - \ - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ - \ - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ - \ - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ - \ - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ - \ - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ - \ - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v14.4s\n" \ - "fadd v12.4s, v12.4s, v16.4s\n" \ - \ - "fadd v13.4s, v13.4s, v15.4s\n" \ - "fadd v13.4s, v13.4s, v17.4s\n" \ - \ - "fadd v12.4s, v12.4s, %[bias].4s\n" \ - "fadd v13.4s, v13.4s, %[bias].4s\n" - -#define RESULT_S_S1 \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define RESULT_S_S1_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fmax v12.4s, v12.4s, %[zero].4s\n" \ - "fmax v13.4s, v13.4s, %[zero].4s\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define COMPUTE_S_S1_P0 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" \ - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" \ - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ - \ - "bif v0.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "bif v2.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "bif v4.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v5.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "bif v6.16b, %[zero].16b, %[mask1].16b\n" \ - "bif 
v7.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "ext v8.16b, v0.16b, v1.16b, #4\n" \ - "ext v9.16b, v0.16b, v1.16b, #8\n" \ - \ - "and v12.16b, %[vbias].16b, %[vbias].16b \n" \ - "and v13.16b, %[vbias].16b, %[vbias].16b \n" /* r0 */ \ - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" \ - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" \ - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "ext v8.16b, v2.16b, v3.16b, #4\n" \ - "ext v9.16b, v2.16b, v3.16b, #8\n" /* r1 */ \ - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" \ - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" \ - \ - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" \ - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" \ - \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" \ - \ - "ext v8.16b, v4.16b, v5.16b, #4\n" \ - "ext v9.16b, v4.16b, v5.16b, #8\n" /* r2 */ \ - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" \ - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" \ - \ - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" \ - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" \ - \ - "fmla v13.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" \ - \ - "ext v8.16b, v6.16b, v7.16b, #4\n" \ - "ext v9.16b, v6.16b, v7.16b, #8\n" \ - \ - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" \ - \ - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" \ - \ - "fadd v12.4s, v12.4s, v10.4s\n" \ - \ - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v11.4s\n" \ - "fadd v13.4s, v13.4s, v14.4s\n" \ - "fadd v13.4s, v13.4s, v15.4s\n" // \ - // "prfm pldl1keep, [%[out1]]\n" \ - // "prfm pldl1keep, [%[out2]]\n" \ - // \ - // "st1 {v12.4s}, [%[out1]]\n" \ - // "st1 {v13.4s}, [%[out2]]\n" \ - - -#else -#define INIT_S1 \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" \ - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" \ - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" \ - "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" \ - \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vdup.32 q5, %[bias_val] @ and \n" - -#define LEFT_COMPUTE_S1 \ - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" \ - "vext.32 q7, q8, q9, #1 @ 1234\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" \ - "vext.32 q7, q10, q11, #1 @ 1234\n" \ - \ - /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ - "vld1.32 {d20-d21}, [%[din1_ptr]]! 
@ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" \ - "vext.32 q7, q12, q13, #1 @ 1234\n" \ - \ - /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" \ - "vext.32 q7, q14, q15, #1 @ 1234\n" - -#define LEFT_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define MID_COMPUTE_S1 \ - "1: @ right pad entry\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d24-d25}, [%[din2_ptr]]! 
@ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" - -#define MID_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define RIGHT_COMPUTE_S1 \ - "3: @ right pad entry\n" \ - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ - \ - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" \ - \ - "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[rmask]]! 
@ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" \ - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" - -#define RIGHT_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define LEFT_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define MID_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define RIGHT_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define COMPUTE_S_S1 \ - "pld [%[din0]]\n" \ - "pld [%[din1]]\n" \ - "pld [%[din2]]\n" \ - "pld [%[din3]]\n" \ - \ - "vld1.32 {d12-d13}, [%[din0]]!\n" \ - "vld1.32 {d14-d15}, [%[din1]]!\n" \ - "vld1.32 {d16-d17}, [%[din2]]!\n" \ - "vld1.32 {d18-d19}, [%[din3]]!\n" \ - \ - "vbif q6, %q[vzero], %q[mask]\n" \ - "vbif q7, %q[vzero], %q[mask]\n" \ - "vbif q8, %q[vzero], %q[mask]\n" \ - "vbif q9, %q[vzero], %q[mask]\n" \ - \ - "vmul.f32 q14, q6, %e[wr0][1]\n" \ - "vmul.f32 q15, q7, %e[wr0][1]\n" \ - \ - "vmla.f32 q14, q7, %e[wr1][1]\n" \ - "vmla.f32 q15, q8, %e[wr1][1]\n" \ - \ - "vmla.f32 q14, q8, %e[wr2][1]\n" \ - "vmla.f32 q15, q9, %e[wr2][1]\n" \ - \ - "vext.32 q10, %q[vzero], q6, #3\n" \ - "vext.32 q11, %q[vzero], q7, #3\n" \ - "vext.32 q12, %q[vzero], q8, #3\n" \ - "vext.32 q13, %q[vzero], q9, #3\n" \ - \ - "vmla.f32 q14, q10, %e[wr0][0]\n" \ - "vmla.f32 q15, q11, %e[wr0][0]\n" \ - \ - "vmla.f32 q14, q11, %e[wr1][0]\n" \ - "vmla.f32 q15, q12, %e[wr1][0]\n" \ - \ - "vmla.f32 q14, q12, %e[wr2][0]\n" \ - "vmla.f32 q15, q13, %e[wr2][0]\n" \ - \ - "vext.32 q10, q6, %q[vzero], #1\n" \ - "vext.32 q11, q7, %q[vzero], #1\n" \ - "vext.32 q12, q8, %q[vzero], #1\n" \ - "vext.32 q13, q9, %q[vzero], #1\n" \ - \ - "vmla.f32 q14, q10, %f[wr0][0]\n" \ - "vmla.f32 q15, q11, %f[wr0][0]\n" \ - \ - "vmla.f32 q14, q11, %f[wr1][0]\n" \ - "vmla.f32 q15, q12, %f[wr1][0]\n" \ - \ - "vmla.f32 q14, q12, %f[wr2][0]\n" \ - "vmla.f32 q15, q13, %f[wr2][0]\n" \ - \ - "vadd.f32 q14, q14, %q[bias]\n" \ - "vadd.f32 q15, q15, %q[bias]\n" - -#define RESULT_S_S1 \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define RESULT_S_S1_RELU \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vmax.f32 q14, q14, %q[vzero]\n" \ - "vmax.f32 q15, q15, %q[vzero]\n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define COMPUTE_S_S1_P0 \ - "pld [%[din0]]\n" \ - "pld [%[din1]]\n" \ - "pld [%[din2]]\n" \ - "pld [%[din3]]\n" \ - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" \ - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" \ - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" \ - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" \ - \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ - \ - "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" \ - \ - "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ - \ - "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ - \ - "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ - "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" \ - \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vadd.f32 q14, q4, q11 @ q4 += q10 \n" \ - \ - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" \ - "vadd.f32 q15, q5, q9 @ q4 += q10 \n" - -#endif -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -void conv_depthwise_3x3s1p1_bias(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
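/* Reference sketch (illustrative): a naive scalar version of what
 * conv_depthwise_3x3s1p1_bias computes for one channel -- stride 1, implicit
 * zero pad of 1, optional bias and relu -- usable as a correctness oracle
 * for the NEON paths below. The helper and its argument layout (row-major
 * single channel) are hypothetical. */
static void dw3x3s1p1_ref(float* out, const float* in, const float* w,
                          float bias, bool relu, int h, int wid) {
  for (int oy = 0; oy < h; ++oy) {        // pad = 1 keeps h_out == h_in
    for (int ox = 0; ox < wid; ++ox) {    // and w_out == w_in
      float acc = bias;
      for (int ky = 0; ky < 3; ++ky) {
        for (int kx = 0; kx < 3; ++kx) {
          int iy = oy + ky - 1, ix = ox + kx - 1;  // implicit zero padding
          if (iy >= 0 && iy < h && ix >= 0 && ix < wid)
            acc += in[iy * wid + ix] * w[ky * 3 + kx];
        }
      }
      out[oy * wid + ox] = relu ? (acc > 0.f ? acc : 0.f) : acc;
    }
  }
}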
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 3) >> 2; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
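/* Reference sketch (illustrative): how the right-pad masks consumed by the
 * "bif"/"vbif" instructions above are built. right_pad_idx holds descending
 * lane indices, so lanes with index >= size_pad_right are real pixels and
 * survive; lanes hanging past the row end are replaced by the zero padding
 * value. Helper name is hypothetical. */
#include <arm_neon.h>

static inline float32x4_t mask_right_pad(float32x4_t row,
                                         unsigned int size_pad_right) {
  const unsigned int idx[4] = {5, 4, 3, 2};  // first half of right_pad_idx
  uint32x4_t keep = vcgeq_u32(vld1q_u32(idx), vdupq_n_u32(size_pad_right));
  // "bif row, zero, keep"  ==  keep ? row : 0.f
  return vbslq_f32(keep, row, vdupq_n_f32(0.f));
}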
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - if (flag_relu) { - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } else { - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
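/* Reference sketch (illustrative): the branch-free boundary handling used by
 * both switch statements above. Input rows past h_in are redirected to a
 * zeroed scratch row (zero_ptr) and output rows past h_out to a throw-away
 * row (write_ptr), so the fixed multi-row asm body always runs; the
 * intentional switch fall-through redirects every row from the first
 * out-of-range one onward. */
static void redirect_out_of_range_rows(const float* din[6], float* dout[4],
                                       int i, int h_in, int h_out,
                                       const float* zero_row,
                                       float* trash_row) {
  if (i + 5 > h_in) {
    switch (i + 5 - h_in) {         // deliberate fall-through below
      case 5: din[1] = zero_row;    // FALLTHROUGH
      case 4: din[2] = zero_row;    // FALLTHROUGH
      case 3: din[3] = zero_row;    // FALLTHROUGH
      case 2: din[4] = zero_row;    // FALLTHROUGH
      case 1: din[5] = zero_row;
      default: break;
    }
  }
  if (i + 4 > h_out) {
    switch (i + 4 - h_out) {
      case 3: dout[1] = trash_row;  // FALLTHROUGH
      case 2: dout[2] = trash_row;  // FALLTHROUGH
      case 1: dout[3] = trash_row;
      default: break;
    }
  }
}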
process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - if (flag_relu) { - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p1_bias_s(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
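/* Reference sketch (illustrative, aarch64): the operand-binding pattern
 * shared by every asm volatile block in this file. "+r" marks pointers the
 * assembly both reads and post-increments, "w" pins vectors into SIMD
 * registers, and the clobber list names each register the macros overwrite
 * so the compiler will not keep live values there. Hypothetical helper. */
static inline void add_one_inplace(float* p) {
  asm volatile(
      "ld1  {v0.4s}, [%[ptr]]       \n" /* load 4 floats              */
      "fmov v1.4s, #1.0             \n" /* splat constant 1.0         */
      "fadd v0.4s, v0.4s, v1.4s     \n" /* elementwise add            */
      "st1  {v0.4s}, [%[ptr]], #16  \n" /* store and advance pointer  */
      : [ptr] "+r"(p)
      :
      : "cc", "memory", "v0", "v1");
}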
for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float *dr0 = din_channel + hs * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); - } else { - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); - } -#else - if (flag_relu) { - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing 
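/* Worked example (hypothetical width) for the w_in <= 4 mask above: with
 * w_in == 3, 4 - w_in == 1, and right_pad_idx = {3,2,1,0}, vcgeq_s32 yields
 * lanes {1,1,1,0}; the three real pixels survive and lane 3 reads as zero
 * padding. The kernel then stores through out_buf1/out_buf2 so only w_out
 * values ever reach the real output row. */
#include <arm_neon.h>
#include <cassert>

static void check_small_width_mask() {
  const int right_pad_idx[4] = {3, 2, 1, 0};
  const int w_in = 3;  // example width
  uint32x4_t m = vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
  unsigned int lanes[4];
  vst1q_u32(lanes, m);
  assert(lanes[0] && lanes[1] && lanes[2] && !lanes[3]);
}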
batchs -} - -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -void conv_depthwise_3x3s1p0_bias(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 4: - din_ptr1 = zero_ptr; - case 3: - din_ptr2 = zero_ptr; - case 2: - din_ptr3 = zero_ptr; - case 1: - din_ptr4 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
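/* Reference sketch (illustrative) of the pad == 0 tiling arithmetic above,
 * with one worked example on hypothetical sizes. */
static inline void p0_tiling(int w_in, int w_out, int* tile_w, int* remain,
                             unsigned int* size_pad_right) {
  *tile_w = w_out >> 2;  // four output columns per main-loop iteration
  *remain = w_out % 4;   // trailing columns handled by the masked tail
  // the last 6-wide input window overruns the row by this many floats
  *size_pad_right = (unsigned int)(6 + (*tile_w << 2) - w_in);
}
// e.g. w_in = 10 gives w_out = 8 for 3x3 s1 p0, so tile_w = 2, remain = 0,
// and size_pad_right = 6 + 8 - 10 = 4 masked lanes in the final loads.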
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - if (flag_relu) { - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - MID_COMPUTE_S1 MID_RESULT_S1_RELU - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } else { - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - MID_COMPUTE_S1 MID_RESULT_S1 - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1 "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - case 0: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - if (flag_relu) { - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1_RELU - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1 - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1 "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p0_bias_s(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - -#ifdef __aarch64__ - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } -#endif // __aarch64__ - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0 = din_channel + j * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - doutr1 = trash_buf; - default: - break; - } - } -#ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } -#else - unsigned int *vmask_ptr = vmask; - float bias_val = flag_bias ? 
bias[i] : 0.f; - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 02a49cf157296763ce3a61ea99dd4ce513dc2f30..96d0893bc0f0a1c145f4e58dd2caecfba78786ab 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -107,29 +107,35 @@ void im2col(const Dtype* data_im, int width, int kernel_h, int kernel_w, - int pad_h, - int pad_w, + int pad_top, + int pad_bottom, + int pad_left, + int pad_right, int stride_h, int stride_w, int dilation_h, int dilation_w, Dtype* data_col) { const int output_h = - (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + (height + pad_top + pad_bottom - (dilation_h * (kernel_h - 1) + 1)) / + stride_h + + 1; const int output_w = - (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + (width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) / + stride_w + + 1; const int channel_size = height * width; for (int channel = channels; channel--; data_im += channel_size) { for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { - int input_row = -pad_h + kernel_row * dilation_h; + int input_row = -pad_top + kernel_row * dilation_h; for (int output_rows = output_h; output_rows; output_rows--) { if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { for (int output_cols = output_w; output_cols; output_cols--) { *(data_col++) = 0; } } else { - int input_col = -pad_w + kernel_col * dilation_w; + int input_col = -pad_left + kernel_col * dilation_w; for (int output_col = output_w; output_col; output_col--) { if (is_a_ge_zero_and_a_lt_b(input_col, width)) { *(data_col++) = data_im[input_row * width + input_col]; @@ -174,13 +180,14 @@ void conv1x1s1_gemm(const float* i_data, bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; + auto act_param = param.activation_param; + int hblock = get_hblock(ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int weights_size_per_group = m * k; if (n > 1) { weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; } - //! 
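/* Reference sketch (simplified: one channel, dilation 1): the effect of the
 * four-sided padding change to im2col above. The output extent is now derived
 * from pad_top + pad_bottom and pad_left + pad_right instead of 2 * pad, and
 * taps that fall outside the image stay zero. Illustrative helper only. */
#include <vector>

static std::vector<float> im2col_1ch(const float* im, int h, int w, int kh,
                                     int kw, int pad_top, int pad_bottom,
                                     int pad_left, int pad_right, int stride_h,
                                     int stride_w) {
  const int out_h = (h + pad_top + pad_bottom - kh) / stride_h + 1;
  const int out_w = (w + pad_left + pad_right - kw) / stride_w + 1;
  std::vector<float> col(static_cast<size_t>(kh) * kw * out_h * out_w, 0.f);
  float* dst = col.data();
  for (int kr = 0; kr < kh; ++kr) {
    for (int kc = 0; kc < kw; ++kc) {
      for (int orow = 0; orow < out_h; ++orow) {
        const int ir = orow * stride_h - pad_top + kr;
        for (int ocol = 0; ocol < out_w; ++ocol) {
          const int ic = ocol * stride_w - pad_left + kc;
          if (ir >= 0 && ir < h && ic >= 0 && ic < w)
            *dst = im[ir * w + ic];  // in-range tap; others stay zero
          ++dst;
        }
      }
    }
  }
  return col;
}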
use gemv when the output channel size = 1 for (int b = 0; b < num; ++b) { // dC @@ -202,8 +209,11 @@ void conv1x1s1_gemm(const float* i_data, k, flag_bias, bias_group, - flag_relu, - ctx); + act_param.has_active, + act_param.active_type, + ctx, + act_param.Relu_clipped_coef, + act_param.Leaky_relu_alpha); } else { sgemm_prepack(false, m, @@ -217,7 +227,7 @@ void conv1x1s1_gemm(const float* i_data, n, bias_group, flag_bias, - flag_relu, + act_param, ctx); } } @@ -355,6 +365,8 @@ void conv_im2col_gemm(const float* i_data, int hblock = get_hblock(ctx); int m_roundup = hblock * ((m + hblock - 1) / hblock); int weights_size_per_group = m * k; + + auto act_param = param.activation_param; if (n > 1) { weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; } @@ -362,6 +374,8 @@ void conv_im2col_gemm(const float* i_data, float* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(float); + auto paddings = *param.paddings; + auto dilations = *param.dilations; //! use gemv when the output channel size = 1 for (int b = 0; b < num; ++b) { // dC @@ -379,12 +393,14 @@ void conv_im2col_gemm(const float* i_data, win, kernel_h, kernel_w, - param.paddings[0], - param.paddings[1], + paddings[0], + paddings[1], + paddings[2], + paddings[3], param.strides[0], param.strides[1], - param.dilations[0], - param.dilations[1], + dilations[0], + dilations[1], dB); if (n == 1) { @@ -396,8 +412,11 @@ void conv_im2col_gemm(const float* i_data, k, flag_bias, bias_group, - flag_relu, - ctx); + act_param.has_active, + act_param.active_type, + ctx, + act_param.Relu_clipped_coef, + act_param.Leaky_relu_alpha); } else { int ldb = n; sgemm_prepack(false, @@ -412,7 +431,7 @@ void conv_im2col_gemm(const float* i_data, n, bias_group, flag_bias, - flag_relu, + act_param, ctx); } } @@ -436,14 +455,16 @@ void conv_im2col_gemm_int8(const int8_t* i_data, const float* scale) { int group = param.groups; auto filter_dims = param.filter->dims(); + auto paddings = *param.paddings; + auto dilations = *param.dilations; int kernel_h = filter_dims[2]; int kernel_w = filter_dims[3]; int stride_h = param.strides[0]; int stride_w = param.strides[1]; - int dila_h = param.dilations[0]; - int dila_w = param.dilations[1]; - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + int dila_h = dilations[0]; + int dila_w = dilations[1]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; const int m = oc / group; const int n = oh * ow; const int k = ic * kernel_h * kernel_w / group; @@ -484,7 +505,9 @@ void conv_im2col_gemm_int8(const int8_t* i_data, kernel_h, kernel_w, pad_h, + paddings[1], pad_w, + paddings[3], stride_h, stride_w, dila_h, @@ -564,90 +587,83 @@ void conv_depthwise_3x3_fp32(const void* din, const operators::ConvParam& param, ARMContext* ctx, const float* scale) { - const int pad_h = param.paddings[0]; - const int pad_w = param.paddings[1]; - if (pad_w != pad_h) { - LOG(FATAL) << "fp32 depthwise conv3x3 pad_w: " << pad_w - << ", pad_h: " << pad_h << " must be equal"; - return; - } + auto paddings = *param.paddings; + auto act_param = param.activation_param; + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; int stride = param.strides[1]; int pad = pad_w; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; - if (stride == 1 && pad < 2) { // support pad = [0, 1] - conv_depthwise_3x3s1_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - pad, - flag_bias, - flag_relu, - ctx); - 
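/* Reference sketch (illustrative): what the act_param plumbing above lets the
 * gemv/sgemm epilogue express, compared with the old single flag_relu bool.
 * The enum and helper are hypothetical stand-ins for the ConvParam fields
 * (has_active, active_type, Relu_clipped_coef, Leaky_relu_alpha). */
enum class ActKind { kRelu, kRelu6, kLeakyRelu };

static inline float apply_activation(float v, bool has_active, ActKind kind,
                                     float relu6_coef, float leaky_alpha) {
  if (!has_active) return v;
  switch (kind) {
    case ActKind::kRelu:
      return v > 0.f ? v : 0.f;
    case ActKind::kRelu6:  // clipped relu; coef is typically 6.f
      return v < 0.f ? 0.f : (v > relu6_coef ? relu6_coef : v);
    case ActKind::kLeakyRelu:
      return v > 0.f ? v : leaky_alpha * v;
  }
  return v;
}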
} else if (stride == 2 && pad < 2) { // support pad = [0, 1] - conv_depthwise_3x3s2_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else { - LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride - << " or pad(<2): " << pad << " unsupported"; - } -#if 0 - if (pad == 1) { - conv_depthwise_3x3p1_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - stride, - flag_bias, - flag_relu, - ctx); - } else if (pad == 0 && h_in > 2) { - conv_depthwise_3x3p0_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - stride, - flag_bias, - flag_relu, - ctx); + bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); + if (stride == 1) { + if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] + conv_depthwise_3x3s1_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + act_param, + ctx); + } else { + conv_3x3s1_depthwise_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + param, + act_param, + ctx); + } + } else if (stride == 2) { + if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] + conv_depthwise_3x3s2_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + pad, + flag_bias, + act_param, + ctx); + } else { + conv_3x3s2_depthwise_fp32(reinterpret_cast(din), + reinterpret_cast(dout), + num, + ch_out, + h_out, + w_out, + ch_in, + h_in, + w_in, + reinterpret_cast(weights), + bias, + param, + act_param, + ctx); + } } else { - LOG(FATAL) << "unsupport this type 3x3 dw conv"; + LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride << " unsupported"; } -#endif } void conv_depthwise_5x5_fp32(const void* din, @@ -664,12 +680,15 @@ void conv_depthwise_5x5_fp32(const void* din, const operators::ConvParam& param, ARMContext* ctx, const float* scale) { - int pad = param.paddings[1]; + auto paddings = *param.paddings; + auto act_param = param.activation_param; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; ctx->ExtendWorkspace((w_in + w_out) * sizeof(float)); - if (pad == 2 && stride == 2) { + if (stride == 2) { conv_depthwise_5x5s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -681,25 +700,25 @@ void conv_depthwise_5x5_fp32(const void* din, w_in, reinterpret_cast(weights), bias, - pad, - flag_bias, - flag_relu, + param, + act_param, ctx); } else if (stride == 1) { - conv_depthwise_5x5s1_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, + conv_depthwise_5x5s1_fp32(reinterpret_cast(dout), + reinterpret_cast(din), reinterpret_cast(weights), bias, - pad, flag_bias, flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + param, ctx); } else { LOG(FATAL) << "unsupport this type 5x5 dw conv"; @@ -720,8 +739,9 @@ void conv_depthwise_3x3_int8_fp32(const void* din, const operators::ConvParam& param, ARMContext* ctx, 
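/* Condensed view (for reading only) of the new fp32 depthwise-3x3 dispatch
 * above: the specialized s1/s2 kernels are taken only for small, square
 * padding, and every other padding shape now falls back to the generic
 * depthwise implementations instead of aborting.
 *
 *   pads_less = paddings[1] < 2 && paddings[3] < 2;
 *   stride 1: (pads_less && pad_h == pad_w && pad < 2)
 *               ? conv_depthwise_3x3s1_fp32(...)   // fast path
 *               : conv_3x3s1_depthwise_fp32(...);  // generic path
 *   stride 2: same split between conv_depthwise_3x3s2_fp32 and
 *             conv_3x3s2_depthwise_fp32;
 *   other strides: LOG(FATAL).
 */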
const float* scale) { - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + auto paddings = *param.paddings; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; @@ -778,8 +798,9 @@ void conv_depthwise_3x3_int8_int8(const void* din, const operators::ConvParam& param, ARMContext* ctx, const float* scale) { - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + auto paddings = *param.paddings; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; @@ -836,8 +857,9 @@ void conv_depthwise_5x5_int8_fp32(const void* din, const operators::ConvParam& param, ARMContext* ctx, const float* scale) { - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + auto paddings = *param.paddings; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; @@ -858,6 +880,23 @@ void conv_depthwise_5x5_int8_fp32(const void* din, pad_w, pad_h, ctx); + } else if (stride == 2) { + conv_depthwise_5x5s2_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); } else { LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; } @@ -877,8 +916,9 @@ void conv_depthwise_5x5_int8_int8(const void* din, const operators::ConvParam& param, ARMContext* ctx, const float* scale) { - int pad_h = param.paddings[0]; - int pad_w = param.paddings[1]; + auto paddings = *param.paddings; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; @@ -899,6 +939,23 @@ void conv_depthwise_5x5_int8_int8(const void* din, pad_w, pad_h, ctx); + } else if (stride == 2) { + conv_depthwise_5x5s2_int8(reinterpret_cast(dout), + reinterpret_cast(din), + reinterpret_cast(weights), + scale, + bias, + flag_bias, + flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + ctx); } else { LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; } diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h index c5baa31e1414c4a7a0c926728e5c150c0fc3e21c..60f74b7feecc91a2fe8262a1fea4dce26430031d 100644 --- a/lite/backends/arm/math/conv_impl.h +++ b/lite/backends/arm/math/conv_impl.h @@ -314,7 +314,51 @@ void fill_bias_int8(int* tensor, const int* bias, int channel, int channel_size); +// new winograd +void weight_trans_c4_8x8( + float* dest, const float* src, int ic, int oc, void* workspace); +void weight_trans_c4_4x4( + float* dest, const float* src, int ic, int oc, void* workspace); +void conv_compute_6x6_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx); +void conv_compute_2x2_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx); +void conv_compute_2x2_3x3_small(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* 
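/* Reference sketch (illustrative) for the new Winograd entry points declared
 * above: conv_compute_6x6_3x3 is F(6x6, 3x3), consuming 8x8 input tiles
 * (6 outputs plus a 2-pixel halo) per 6x6 output block, and the 2x2 variants
 * are F(2x2, 3x3) on 4x4 input tiles. The helper below only counts tiles. */
static inline int winograd_tile_count(int hout, int wout, int m /* 6 or 2 */) {
  const int tile_h = (hout + m - 1) / m;  // ceil(hout / m)
  const int tile_w = (wout + m - 1) / m;  // ceil(wout / m)
  return tile_h * tile_w;
}
// Consistent with the "tile_w = (wout + 5) / 6" arithmetic in
// conv_winograd3x3 below.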
diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h
index c5baa31e1414c4a7a0c926728e5c150c0fc3e21c..60f74b7feecc91a2fe8262a1fea4dce26430031d 100644
--- a/lite/backends/arm/math/conv_impl.h
+++ b/lite/backends/arm/math/conv_impl.h
@@ -314,7 +314,51 @@ void fill_bias_int8(int* tensor,
                     const int* bias,
                     int channel,
                     int channel_size);
+// new winograd
+void weight_trans_c4_8x8(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void weight_trans_c4_4x4(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void conv_compute_6x6_3x3(const float* input,
+                          float* output,
+                          int num, int chout, int hout, int wout,
+                          int chin, int hin, int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx);
+void conv_compute_2x2_3x3(const float* input,
+                          float* output,
+                          int num, int chout, int hout, int wout,
+                          int chin, int hin, int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx);
+void conv_compute_2x2_3x3_small(const float* input,
+                                float* output,
+                                int num, int chout, int hout, int wout,
+                                int chin, int hin, int win,
+                                const float* weight,
+                                const float* bias,
+                                const operators::ConvParam& param,
+                                ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/conv_winograd_3x3.cc b/lite/backends/arm/math/conv_winograd_3x3.cc
index 87b08f63102104b325e95c093fe0fc0aaef243e0..449c9e51db1e67b2a9f0d2d0f6ed0c2c2b2b2772 100644
--- a/lite/backends/arm/math/conv_winograd_3x3.cc
+++ b/lite/backends/arm/math/conv_winograd_3x3.cc
@@ -37,13 +37,15 @@ void conv_winograd3x3(const float* din,
                       const operators::ConvParam& param,
                       ARMContext* ctx) {
   int threads = ctx->threads();
-
-  const int pad_h = param.paddings[0];
-  const int pad_w = param.paddings[1];
+  auto paddings = *param.paddings;
+  const int pad_h = paddings[0];
+  const int pad_w = paddings[1];
   int size_in_channel = win * hin;
   int size_out_channel = wout * hout;
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  auto act_param = param.activation_param;
+  act_param.has_active = false;
   //! transform input
   int tile_w = (wout + 5) / 6;
@@ -127,7 +129,7 @@ void conv_winograd3x3(const float* din,
                 size_tile,
                 nullptr,
                 false,
-                false,
+                act_param,
                 ctx);
 }
diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index a4c61f9a9d181924c28cdd009f8412278d44f5bb..186ad19735799dcb91641354af4b4f09692bfce9 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -557,6 +557,52 @@ void elementwise_mul(const float* dinx,
   }
 }
 
+template <>
+void elementwise_mul<int>(const int* dinx,
+                          const int* diny,
+                          int* dout,
+                          int num) {
+  int cnt = num >> 4;
+  int remain = num % 16;
+#pragma omp parallel for
+  for (int i = 0; i < cnt; ++i) {
+    const int* dinx_ptr = dinx + (i << 4);
+    const int* diny_ptr = diny + (i << 4);
+    int* dout_ptr = dout + (i << 4);
+
+    int32x4_t dinx0 = vld1q_s32(dinx_ptr);
+    int32x4_t dinx1 = vld1q_s32(dinx_ptr + 4);
+    int32x4_t dinx2 = vld1q_s32(dinx_ptr + 8);
+    int32x4_t dinx3 = vld1q_s32(dinx_ptr + 12);
+
+    int32x4_t diny0 = vld1q_s32(diny_ptr);
+    int32x4_t diny1 = vld1q_s32(diny_ptr + 4);
+    int32x4_t diny2 = vld1q_s32(diny_ptr + 8);
+    int32x4_t diny3 = vld1q_s32(diny_ptr + 12);
+
+    dinx0 = vmulq_s32(dinx0, diny0);
+    dinx1 = vmulq_s32(dinx1, diny1);
+    dinx2 = vmulq_s32(dinx2, diny2);
+    dinx3 = vmulq_s32(dinx3, diny3);
+
+    vst1q_s32(dout_ptr, dinx0);
+    vst1q_s32(dout_ptr + 4, dinx1);
+    vst1q_s32(dout_ptr + 8, dinx2);
+    vst1q_s32(dout_ptr + 12, dinx3);
+  }
+  if (remain > 0) {
+    const int* dinx_ptr = dinx + (cnt << 4);
+    const int* diny_ptr = diny + (cnt << 4);
+    int* dout_ptr = dout + (cnt << 4);
+    for (int i = 0; i < remain; i++) {
+      *dout_ptr = *dinx_ptr * *diny_ptr;
+      dout_ptr++;
+      dinx_ptr++;
+      diny_ptr++;
+    }
+  }
+}
+
 template <>
 void elementwise_mul_relu(const float* dinx,
                           const float* diny,
@@ -678,6 +724,73 @@ void elementwise_mul_broadcast(const float* dinx,
   }
 }
 
+template <>
+void elementwise_mul_broadcast<int>(const int* dinx,
+                                    const int* diny,
+                                    int* dout,
+                                    int batch,
+                                    int channels,
+                                    int num) {
+#pragma omp parallel for collapse(2)
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const int* din_ptr = dinx + offset;
+      const int diny_data = diny[j];
+      int* dout_ptr = dout + offset;
+
+      int cnt = num >> 4;
+      int remain = num % 16;
+      int32x4_t rb = vdupq_n_s32(diny_data);
+      for (int k = 0; k < cnt; ++k) {
+        int32x4_t din0 = vld1q_s32(din_ptr);
+        int32x4_t din1 = vld1q_s32(din_ptr + 4);
+        int32x4_t din2 = vld1q_s32(din_ptr + 8);
+        int32x4_t din3 =
vld1q_s32(din_ptr + 12); + + din0 = vmulq_s32(din0, rb); + din1 = vmulq_s32(din1, rb); + din2 = vmulq_s32(din2, rb); + din3 = vmulq_s32(din3, rb); + + vst1q_s32(dout_ptr, din0); + vst1q_s32(dout_ptr + 4, din1); + vst1q_s32(dout_ptr + 8, din2); + vst1q_s32(dout_ptr + 12, din3); + + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + int32x4_t din0 = vld1q_s32(din_ptr); + int32x4_t din1 = vld1q_s32(din_ptr + 4); + din0 = vmulq_s32(din0, rb); + din1 = vmulq_s32(din1, rb); + vst1q_s32(dout_ptr, din0); + vst1q_s32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + int32x4_t din0 = vld1q_s32(din_ptr); + din0 = vmulq_s32(din0, rb); + vst1q_s32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; ++p) { + *dout_ptr = *din_ptr * diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } +} + template <> void elementwise_mul_relu_broadcast(const float* dinx, const float* diny, diff --git a/lite/backends/arm/math/fill_bias_relu.cc b/lite/backends/arm/math/fill_bias_relu.cc index 7137a0363ba42b9c6416c6f98b0d4a6b5a1687fb..d816c2f549c2c074a35885931a585ff51ae97f6f 100644 --- a/lite/backends/arm/math/fill_bias_relu.cc +++ b/lite/backends/arm/math/fill_bias_relu.cc @@ -115,7 +115,241 @@ void fill_bias_relu(int* tensor, } } } - +#ifdef __aarch64__ +#define FILL_BIAS \ + "1: \n" \ + "ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "add v0.4s, v0.4s, %[vbias].4s \n" \ + "add v1.4s, v1.4s, %[vbias].4s \n" \ + "add v2.4s, v2.4s, %[vbias].4s \n" \ + "add v3.4s, v3.4s, %[vbias].4s \n" +#define FILL_RELU \ + "fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \ + "fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \ + "fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \ + "fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */ +#define FILL_RELU6 \ + "fmin v0.4s, v0.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ + "fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ + "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ + "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */ +#define FILL_LEAKY_RELU \ + "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ + "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v10.16b \n" /* choose*/ +#define FILL_STORE \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "bne 1b \n" +#else +#define FILL_BIAS \ + "1: \n" \ + "vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \ + "vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \ + "vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \ + "vld1.32 {d12-d13}, [%[din_ptr]]! 
@ vld1q_f32(din_ptr) \n" \ + "vadd.f32 q3, q3, %q[vbias] @ add \n" \ + "vadd.f32 q4, q4, %q[vbias] @ add \n" \ + "vadd.f32 q5, q5, %q[vbias] @ add \n" \ + "vadd.f32 q6, q6, %q[vbias] @ add \n" +#define FILL_RELU \ + "vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \ + "vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \ + "vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \ + "vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n" +#define FILL_RELU6 \ + "vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \ + "vmin.f32 q4, q4, %q[vsix] @ vmaxq_f32() \n" \ + "vmin.f32 q5, q5, %q[vsix] @ vmaxq_f32() \n" \ + "vmin.f32 q6, q6, %q[vsix] @ vmaxq_f32() \n" +#define FILL_LEAKY_RELU \ + "vcge.f32 q7, q3, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \ + "vcge.f32 q9, q4, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \ + "vcge.f32 q11, q5, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \ + "vcge.f32 q13, q6, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \ + "vbif q3, q8, q7 @ choose \n" \ + "vbif q4, q10, q9 @ choose \n" \ + "vbif q5, q12, q11 @ choose \n" \ + "vbif q6, q14, q13 @ choose \n" +#define FILL_STORE \ + "subs %[cnt], #1 \n" \ + "vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "bne 1b \n" +#endif +template <> +void fill_bias_act(float* tensor, + const float* bias, + int channel, + int channel_size, + bool flag_bias, + const operators::ActivationParam* act_param) { + float* data = tensor; + int cnt = channel_size >> 4; + int remain = channel_size % 16; + float32x4_t vzero = vdupq_n_f32(0.f); + if (act_param != nullptr && act_param->has_active) { + float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha); + for (int j = 0; j < channel; j++) { + float bias_data = flag_bias ? 
bias[j] : 0.f;
+      float* src = data + j * channel_size;
+      float* dst = data + j * channel_size;
+      float32x4_t vbias = vdupq_n_f32(bias_data);
+      int cnt_num = cnt;  // local copy: the inline asm below decrements it
+      if (cnt_num > 0) {
+        switch (act_param->active_type) {
+          case lite_api::ActivationType::kRelu:
+#ifdef __aarch64__
+            asm volatile(
+                FILL_BIAS FILL_RELU FILL_STORE
+                : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt_num)
+                : [vzero] "w"(vzero), [vbias] "w"(vbias)
+                : "memory", "cc", "v0", "v1", "v2", "v3");
+#else
+            asm volatile(
+                FILL_BIAS FILL_RELU FILL_STORE
+                : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt_num)
+                : [vzero] "w"(vzero), [vbias] "w"(vbias)
+                : "memory", "cc", "q3", "q4", "q5", "q6");
+#endif
+            break;
+          case lite_api::ActivationType::kRelu6:
+#ifdef __aarch64__
+            asm volatile(
+                FILL_BIAS FILL_RELU FILL_RELU6 FILL_STORE
+                : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt_num)
+                : [vzero] "w"(vzero), [vsix] "w"(vsix), [vbias] "w"(vbias)
+                : "memory", "cc", "v0", "v1", "v2", "v3");
+#else
+            asm volatile(
+                FILL_BIAS FILL_RELU FILL_RELU6 FILL_STORE
+                : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt_num)
+                : [vzero] "w"(vzero), [vsix] "w"(vsix), [vbias] "w"(vbias)
+                : "memory", "cc", "q3", "q4", "q5", "q6");
+#endif
+            break;
+          case lite_api::ActivationType::kLeakyRelu:
+#ifdef __aarch64__
+            asm volatile(
+                FILL_BIAS FILL_LEAKY_RELU FILL_STORE
+                : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt_num)
+                : [vzero] "w"(vzero), [vscale] "w"(vscale), [vbias] "w"(vbias)
+                : "memory", "cc", "v0", "v1", "v2", "v3", "v4", "v5",
+                  "v6", "v7", "v8", "v9", "v10", "v11");
+#else
+            asm volatile(
+                FILL_BIAS FILL_LEAKY_RELU FILL_STORE
+                : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt_num)
+                : [vzero] "w"(vzero), [vscale] "w"(vscale), [vbias] "w"(vbias)
+                : "memory", "cc", "q3", "q4", "q5", "q6", "q7", "q8",
+                  "q9", "q10", "q11", "q12", "q13", "q14");
+#endif
+            break;
+          default:
+            LOG(FATAL) << "this act_type: "
+                       << static_cast<int>(act_param->active_type)
+                       << " fuse not support";
+        }
+      }
+      // scalar tail: add bias, then apply the same activation
+      switch (act_param->active_type) {
+        case lite_api::ActivationType::kRelu:
+          for (int i = 0; i < remain; i++) {
+            float tmp = *src + bias_data;
+            *dst = tmp >= 0.f ? tmp : 0.f;
+            src++;
+            dst++;
+          }
+          break;
+        case lite_api::ActivationType::kRelu6:
+          for (int i = 0; i < remain; i++) {
+            float tmp = *src + bias_data;
+            tmp = tmp >= 0.f ? tmp : 0.f;
+            *dst = tmp <= act_param->Relu_clipped_coef
+                       ? tmp
+                       : act_param->Relu_clipped_coef;
+            src++;
+            dst++;
+          }
+          break;
+        case lite_api::ActivationType::kLeakyRelu:
+          for (int i = 0; i < remain; i++) {
+            float tmp = *src + bias_data;
+            *dst = tmp >= 0.f ? tmp : tmp * act_param->Leaky_relu_alpha;
+            src++;
+            dst++;
+          }
+          break;
+        default:
+          LOG(FATAL) << "this act_type: "
+                     << static_cast<int>(act_param->active_type)
+                     << " fuse not support";
+      }
+    }
+  } else {
+    for (int j = 0; j < channel; ++j) {
+      float bias_data = flag_bias ?
bias[j] : 0.f;
+      float32x4_t vbias = vdupq_n_f32(bias_data);
+      float* src = data + j * channel_size;
+      float* dst = data + j * channel_size;
+      int cnt_num = cnt;  // local copy: the inline asm decrements it
+      if (cnt_num > 0) {
+#ifdef __aarch64__
+        asm volatile(FILL_BIAS FILL_STORE
+                     : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst),
+                       [cnt] "+r"(cnt_num)
+                     : [vbias] "w"(vbias)
+                     : "memory", "cc", "v0", "v1", "v2", "v3");
+#else
+        asm volatile(FILL_BIAS FILL_STORE
+                     : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst),
+                       [cnt] "+r"(cnt_num)
+                     : [vbias] "w"(vbias)
+                     : "memory", "cc", "q3", "q4", "q5", "q6");
+#endif
+      }
+      for (int i = 0; i < remain; i++) {  // scalar tail: bias only
+        *dst++ = *src++ + bias_data;
+      }
+    }
+  }
+}
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/fill_bias_relu.h b/lite/backends/arm/math/fill_bias_relu.h
index 254d6d43be8aca8b17cb2fb2107977095facba51..ce775a96a13dad7fddac34e211fc19267a9d48fc 100644
--- a/lite/backends/arm/math/fill_bias_relu.h
+++ b/lite/backends/arm/math/fill_bias_relu.h
@@ -37,7 +37,22 @@ void fill_bias_relu(Dtype* tensor,
                     int channel_size,
                     bool flag_bias,
                     bool flag_relu);
-
+/**
+ * \brief neon implementation to add bias and activation (relu, relu6,
+ *        leakyrelu)
+ * @param tensor
+ * @param bias
+ * @param channel
+ * @param channel_size
+ */
+template <typename Dtype>
+void fill_bias_act(Dtype* tensor,
+                   const Dtype* bias,
+                   int channel,
+                   int channel_size,
+                   bool flag_bias,
+                   const operators::ActivationParam* act_param);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/funcs.cc b/lite/backends/arm/math/funcs.cc
index e4425ade2efebdaad9136f75c39493f2bd3df4ca..8d20e5242e556c86a1283a64ff9ccf51e2efa247 100644
--- a/lite/backends/arm/math/funcs.cc
+++ b/lite/backends/arm/math/funcs.cc
@@ -21,128 +21,179 @@ namespace arm {
 namespace math {
 
 template <>
-void fill_bias_fc<float>(float *out, const float *bias, int num, int channel) {
+void fill_bias_fc<float>(
+    float *out, const float *bias, int num, int channel, bool flag_relu) {
   int cnt = channel >> 4;
   int remain = channel & 15;
-
-  for (int j = 0; j < num; ++j) {
-    const float *ptr_bias = bias;
-    float *ptr_out = out + j * channel;
-
-    float32x4_t vout1;
-    float32x4_t vout2;
-    float32x4_t vout3;
-    float32x4_t vout4;
-
-    for (int i = 0; i < cnt; ++i) {
-      float32x4_t vin1 = vld1q_f32(ptr_out);
-      float32x4_t vb1 = vld1q_f32(ptr_bias);
-      float32x4_t vin2 = vld1q_f32(ptr_out + 4);
-      float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
-      float32x4_t vin3 = vld1q_f32(ptr_out + 8);
-      float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
-      float32x4_t vin4 = vld1q_f32(ptr_out + 12);
-      float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
-
-      vout1 = vaddq_f32(vin1, vb1);
-      vout2 = vaddq_f32(vin2, vb2);
-      vout3 = vaddq_f32(vin3, vb3);
-      vout4 = vaddq_f32(vin4, vb4);
-
-      vst1q_f32(ptr_out, vout1);
-      vst1q_f32(ptr_out + 4, vout2);
-      vst1q_f32(ptr_out + 8, vout3);
-      vst1q_f32(ptr_out + 12, vout4);
-
-      ptr_out += 16;
-      ptr_bias += 16;
+  if (flag_relu) {
+    float32x4_t vzero = vdupq_n_f32(0.f);
+    for (int j = 0; j < num; ++j) {
+      const float *ptr_bias = bias;
+      float *ptr_out = out + j * channel;
+
+      for (int i = 0; i < cnt; ++i) {
+        float32x4_t vin1 = vld1q_f32(ptr_out);
+        float32x4_t vb1 = vld1q_f32(ptr_bias);
+        float32x4_t vin2 = vld1q_f32(ptr_out + 4);
+        float32x4_t vb2 = vld1q_f32(ptr_bias + 4);
+        float32x4_t vin3 = vld1q_f32(ptr_out + 8);
+        float32x4_t vb3 = vld1q_f32(ptr_bias + 8);
+        float32x4_t vin4 = vld1q_f32(ptr_out + 12);
+        float32x4_t vb4 = vld1q_f32(ptr_bias + 12);
+
+        float32x4_t vout1 = vaddq_f32(vin1, vb1);
+        float32x4_t vout2 = vaddq_f32(vin2, vb2);
+        float32x4_t vout3 = vaddq_f32(vin3, vb3);
+        float32x4_t vout4 = vaddq_f32(vin4, vb4);
+
+        vout1 =
vmaxq_f32(vout1, vzero); + vout2 = vmaxq_f32(vout2, vzero); + vout3 = vmaxq_f32(vout3, vzero); + vout4 = vmaxq_f32(vout4, vzero); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = *ptr_out > 0.f ? *ptr_out : 0.f; + ptr_out++; + } } -#if 0 - if (cnt > 0) { - asm( - "1: \n" - "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" - "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" - "vadd.f32 q2, q0, q1 @ add bias\n" - "vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" - "subs %[cnt], #1 @ loop count -1\n" - "bne 1b @ jump to main loop\n" - :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ - [cnt] "+r"(cnt) - : - :"q0", "q1", "q2" - ); - } -#endif - for (int i = 0; i < remain; ++i) { - *(ptr_out++) += *(ptr_bias++); + } else { + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + + for (int i = 0; i < cnt; ++i) { + float32x4_t vin1 = vld1q_f32(ptr_out); + float32x4_t vb1 = vld1q_f32(ptr_bias); + + float32x4_t vin2 = vld1q_f32(ptr_out + 4); + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); + + float32x4_t vin3 = vld1q_f32(ptr_out + 8); + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); + + float32x4_t vin4 = vld1q_f32(ptr_out + 12); + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); + + float32x4_t vout1 = vaddq_f32(vin1, vb1); + float32x4_t vout2 = vaddq_f32(vin2, vb2); + float32x4_t vout3 = vaddq_f32(vin3, vb3); + float32x4_t vout4 = vaddq_f32(vin4, vb4); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } } } } template <> -void fill_bias_fc(int *out, const int *bias, int num, int channel) { +void fill_bias_fc( + int *out, const int *bias, int num, int channel, bool flag_relu) { int cnt = channel >> 4; int remain = channel & 15; - - for (int j = 0; j < num; ++j) { - const int *ptr_bias = bias; - int *ptr_out = out + j * channel; - - int32x4_t vout1; - int32x4_t vout2; - int32x4_t vout3; - int32x4_t vout4; - - for (int i = 0; i < cnt; ++i) { - int32x4_t vin1 = vld1q_s32(ptr_out); - int32x4_t vb1 = vld1q_s32(ptr_bias); - - int32x4_t vin2 = vld1q_s32(ptr_out + 4); - int32x4_t vb2 = vld1q_s32(ptr_bias + 4); - - int32x4_t vin3 = vld1q_s32(ptr_out + 8); - int32x4_t vb3 = vld1q_s32(ptr_bias + 8); - - int32x4_t vin4 = vld1q_s32(ptr_out + 12); - int32x4_t vb4 = vld1q_s32(ptr_bias + 12); - - vout1 = vaddq_s32(vin1, vb1); - vout2 = vaddq_s32(vin2, vb2); - vout3 = vaddq_s32(vin3, vb3); - vout4 = vaddq_s32(vin4, vb4); - - vst1q_s32(ptr_out, vout1); - vst1q_s32(ptr_out + 4, vout2); - vst1q_s32(ptr_out + 8, vout3); - vst1q_s32(ptr_out + 12, vout4); - - ptr_out += 16; - ptr_bias += 16; - } - -#if 0 - if (cnt > 0) { - asm( - "1: \n" - "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" - "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" - "vadd.s32 q2, q0, q1 @ add bias\n" - "vst1.32 {d4-d5}, [%[ptr_out]]! 
@ store result\n" - "subs %[cnt], #1 @ loop count -1\n" - "bne 1b @ jump to main loop\n" - :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ - [cnt] "+r"(cnt) - : - :"q0", "q1", "q2" - ); + if (flag_relu) { + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = out + j * channel; + + int32x4_t vzero = vdupq_n_s32(0); + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + int32x4_t vout1 = vaddq_s32(vin1, vb1); + int32x4_t vout2 = vaddq_s32(vin2, vb2); + int32x4_t vout3 = vaddq_s32(vin3, vb3); + int32x4_t vout4 = vaddq_s32(vin4, vb4); + + vout1 = vmaxq_s32(vout1, vzero); + vout2 = vmaxq_s32(vout2, vzero); + vout3 = vmaxq_s32(vout3, vzero); + vout4 = vmaxq_s32(vout4, vzero); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *ptr_out += *(ptr_bias++); + *ptr_out = *ptr_out > 0 ? *ptr_out : 0; + ptr_out++; + } } -#endif - for (int i = 0; i < remain; ++i) { - *(ptr_out++) += *(ptr_bias++); + } else { + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = out + j * channel; + + int32x4_t vout1; + int32x4_t vout2; + int32x4_t vout3; + int32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + vout1 = vaddq_s32(vin1, vb1); + vout2 = vaddq_s32(vin2, vb2); + vout3 = vaddq_s32(vin3, vb3); + vout4 = vaddq_s32(vin4, vb4); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } } } } diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index d8ef6ff47d0392ac15caf2d94b7c53ff63659da2..e975160c97b6e7396ab208805a4d685586ac00c8 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -39,16 +39,19 @@ #include "lite/backends/arm/math/im2sequence.h" #include "lite/backends/arm/math/increment.h" #include "lite/backends/arm/math/interpolate.h" +#include "lite/backends/arm/math/layout.h" #include "lite/backends/arm/math/lrn.h" #include "lite/backends/arm/math/negative.h" #include "lite/backends/arm/math/norm.h" #include "lite/backends/arm/math/packed_sgemm.h" +#include "lite/backends/arm/math/packed_sgemm_c4.h" #include "lite/backends/arm/math/pad2d.h" #include "lite/backends/arm/math/pooling.h" #include "lite/backends/arm/math/power.h" #include "lite/backends/arm/math/prior_box.h" #include "lite/backends/arm/math/reduce_max.h" #include "lite/backends/arm/math/reduce_mean.h" +#include "lite/backends/arm/math/reduce_prod.h" #include "lite/backends/arm/math/scale.h" #include "lite/backends/arm/math/sequence_expand.h" #include "lite/backends/arm/math/sequence_pool.h" @@ -59,6 +62,7 @@ #include 
"lite/backends/arm/math/slice.h" #include "lite/backends/arm/math/softmax.h" #include "lite/backends/arm/math/split.h" +#include "lite/backends/arm/math/split_merge_lod_tenosr.h" #include "lite/backends/arm/math/stack.h" #include "lite/backends/arm/math/topk.h" #include "lite/backends/arm/math/yolo_box.h" @@ -352,7 +356,8 @@ inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { } template -void fill_bias_fc(T* tensor, const T* bias, int num, int channel); +void fill_bias_fc( + T* tensor, const T* bias, int num, int channel, bool flag_relu); template inline float32x4_t vactive_f32(const float32x4_t& x) { diff --git a/lite/backends/arm/math/gru_utils.h b/lite/backends/arm/math/gru_utils.h index 9bef1889b83d1e212c928562f777ba4706c3436a..9d57f81fc584b56ef5552b4fb2e079f3b62390e0 100644 --- a/lite/backends/arm/math/gru_utils.h +++ b/lite/backends/arm/math/gru_utils.h @@ -383,6 +383,8 @@ struct GRUUnitFunctor { const lite_api::ActivationType active_gate, bool origin_mode, ARMContext* ctx) { + operators::ActivationParam act_param; + act_param.has_active = false; if (value.prev_out_value) { sgemm(false, false, @@ -399,7 +401,7 @@ struct GRUUnitFunctor { frame_size * 3, nullptr, false, - false, + act_param, ctx); } gru_unit_reset_act(active_gate, value, frame_size, batch_size); @@ -420,7 +422,7 @@ struct GRUUnitFunctor { frame_size * 3, nullptr, false, - false, + act_param, ctx); } diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc index e9e18043dfc09001ebba23f952a59474630e54aa..1c53142fc53bc785efcbf28fa007d403ad99ab70 100644 --- a/lite/backends/arm/math/interpolate.cc +++ b/lite/backends/arm/math/interpolate.cc @@ -477,17 +477,23 @@ void nearest_interp(const float* src, float scale_h_new = (with_align) ? (static_cast(h_in - 1) / (h_out - 1)) : (static_cast(h_in) / (h_out)); - -#pragma omp parallel for collapse(2) schedule(static) - for (int h = 0; h < h_out; ++h) { - for (int w = 0; w < w_out; ++w) { - int near_x = (with_align) ? static_cast(scale_w_new * w + 0.5) - : static_cast(scale_w_new * w); - int near_y = (with_align) ? static_cast(scale_h_new * h + 0.5) - : static_cast(scale_h_new * h); - near_x = near_x < 0 ? 0 : near_x; - near_y = near_y < 0 ? 
diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc
index e9e18043dfc09001ebba23f952a59474630e54aa..1c53142fc53bc785efcbf28fa007d403ad99ab70 100644
--- a/lite/backends/arm/math/interpolate.cc
+++ b/lite/backends/arm/math/interpolate.cc
@@ -477,17 +477,23 @@ void nearest_interp(const float* src,
   float scale_h_new = (with_align)
                           ? (static_cast<float>(h_in - 1) / (h_out - 1))
                           : (static_cast<float>(h_in) / (h_out));
-
-#pragma omp parallel for collapse(2) schedule(static)
-  for (int h = 0; h < h_out; ++h) {
-    for (int w = 0; w < w_out; ++w) {
-      int near_x = (with_align) ? static_cast<int>(scale_w_new * w + 0.5)
-                                : static_cast<int>(scale_w_new * w);
-      int near_y = (with_align) ? static_cast<int>(scale_h_new * h + 0.5)
-                                : static_cast<int>(scale_h_new * h);
-      near_x = near_x < 0 ? 0 : near_x;
-      near_y = near_y < 0 ? 0 : near_y;
-      dst[h * w_out + w] = src[near_y * w_in + near_x];
+  if (with_align) {
+    for (int h = 0; h < h_out; ++h) {
+      float* dst_p = dst + h * w_out;
+      int near_y = static_cast<int>(scale_h_new * h + 0.5);
+      for (int w = 0; w < w_out; ++w) {
+        int near_x = static_cast<int>(scale_w_new * w + 0.5);
+        *dst_p++ = src[near_y * w_in + near_x];
+      }
+    }
+  } else {
+    for (int h = 0; h < h_out; ++h) {
+      float* dst_p = dst + h * w_out;
+      int near_y = static_cast<int>(scale_h_new * h);
+      for (int w = 0; w < w_out; ++w) {
+        int near_x = static_cast<int>(scale_w_new * w);
+        *dst_p++ = src[near_y * w_in + near_x];
+      }
     }
   }
 }
@@ -520,9 +526,9 @@ void interpolate(lite::Tensor* X,
   }
   auto out_size = OutSize;
   if (out_size != nullptr) {
-    auto out_size_data = get_new_data_from_tensor<float>(out_size);
-    out_height = static_cast<int>(out_size_data[0]);
-    out_width = static_cast<int>(out_size_data[1]);
+    auto out_size_data = get_new_data_from_tensor<int>(out_size);
+    out_height = out_size_data[0];
+    out_width = out_size_data[1];
   }
 }
 float height_scale = scale;
@@ -544,8 +550,10 @@ void interpolate(lite::Tensor* X,
   int out_w = Out->dims()[3];
   int spatial_in = in_h * in_w;
   int spatial_out = out_h * out_w;
-  for (int i = 0; i < count; ++i) {
-    if ("Bilinear" == interpolate_type) {
+
+  if ("Bilinear" == interpolate_type) {
+#pragma omp parallel for
+    for (int i = 0; i < count; ++i) {
       bilinear_interp(din + spatial_in * i,
                       in_w, in_h,
                       dout + spatial_out * i,
                       out_w, out_h,
                       1.f / width_scale,
                       1.f / height_scale,
                       with_align);
-    } else if ("Nearest" == interpolate_type) {
+    }
+  } else if ("Nearest" == interpolate_type) {
+#pragma omp parallel for
+    for (int i = 0; i < count; ++i) {
       nearest_interp(din + spatial_in * i,
                      in_w, in_h,
diff --git a/lite/backends/arm/math/layout.cc b/lite/backends/arm/math/layout.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fd9126ab48c8f829c82d0c78a338074c695f0b9c
--- /dev/null
+++ b/lite/backends/arm/math/layout.cc
@@ -0,0 +1,668 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
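For orientation, the NEON transposes defined in this new file implement the following scalar mapping; a reference version the vectorized paths must agree with (sketch only, not part of the patch):

// Reference semantics for the NEON layout kernels below, with size = H*W:
// NCHW -> NHWC means Y[((n*size + s)*C) + c] = X[((n*C + c)*size) + s].
template <typename T>
void nchw2nhwc_ref(int N, int C, int size, const T* X, T* Y) {
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int s = 0; s < size; ++s) {
        Y[n * C * size + s * C + c] = X[n * C * size + c * size + s];
      }
    }
  }
}
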
+ +#include "lite/backends/arm/math/layout.h" +#include +#include +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +#ifdef __aarch64__ +#define TRANS_C4 \ + "ld1 {v0.4s}, [%[din0_ptr]] \n" \ + "ld1 {v1.4s}, [%[din1_ptr]] \n" \ + "ld1 {v2.4s}, [%[din2_ptr]] \n" \ + "ld1 {v3.4s}, [%[din3_ptr]] \n" \ + \ + "1: \n" \ + "trn1 v4.4s, v0.4s, v1.4s \n" /*00 10 02 12 */ \ + "trn1 v5.4s, v2.4s, v3.4s \n" /*20 30 22 32 */ \ + "trn2 v6.4s, v0.4s, v1.4s \n" /*01 11 03 13 */ \ + "trn2 v7.4s, v2.4s, v3.4s \n" /*21 31 23 33 */ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride] \n" /* din+=c*size*/ \ + \ + "trn1 v8.2d, v4.2d, v5.2d \n" /*00 10 20 30 */ \ + "trn1 v9.2d, v6.2d, v7.2d \n" /*01 11 21 31 */ \ + "trn2 v10.2d, v4.2d, v5.2d \n" /*02 12 22 32 */ \ + "trn2 v11.2d, v6.2d, v7.2d \n" /*03 13 23 33 */ \ + \ + "ld1 {v0.4s}, [%[din0_ptr]] \n" \ + "ld1 {v1.4s}, [%[din1_ptr]] \n" \ + "ld1 {v2.4s}, [%[din2_ptr]] \n" \ + "ld1 {v3.4s}, [%[din3_ptr]] \n" \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "str q8, [%[out0_ptr]], #16 \n" \ + "str q9, [%[out1_ptr]], #16 \n" \ + "str q10, [%[out2_ptr]], #16 \n" \ + "str q11, [%[out3_ptr]], #16 \n" \ + "bne 1b \n" + +#define TRANS_C8 \ + "1: \n" \ + "ld1 {v0.8b}, [%[din0_ptr]] \n" \ + "ld1 {v1.8b}, [%[din1_ptr]] \n" \ + "ld1 {v2.8b}, [%[din2_ptr]] \n" \ + "ld1 {v3.8b}, [%[din3_ptr]] \n" \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "trn1 v8.8b, v0.8b, v1.8b \n" /*00 10 02 12 04 14 06 16 */ \ + "trn1 v9.8b, v2.8b, v3.8b \n" /*20 30 22 32 */ \ + "trn2 v12.8b, v0.8b, v1.8b \n" /*01 11 03 13 05 15 07 17 */ \ + "trn2 v13.8b, v2.8b, v3.8b \n" /*21 31 23 33 */ \ + \ + "ld1 {v4.8b}, [%[din0_ptr]] \n" \ + "ld1 {v5.8b}, [%[din1_ptr]] \n" \ + "ld1 {v6.8b}, [%[din2_ptr]] \n" \ + "ld1 {v7.8b}, [%[din3_ptr]] \n" \ + \ + "trn1 v10.8b, v4.8b, v5.8b \n" /*40 50 42 52 */ \ + "trn1 v11.8b, v6.8b, v7.8b \n" /*60 70 62 72 */ \ + "trn2 v14.8b, v4.8b, v5.8b \n" /*41 51 43 53 */ \ + "trn2 v15.8b, v6.8b, v7.8b \n" /*61 71 63 73 */ \ + \ + "trn1 v0.4h, v8.4h, v9.4h \n" /*00 10 20 30 04 14 24 34*/ \ + "trn1 v2.4h, v12.4h, v13.4h \n" /*01 11 21 31 05 15 25 35*/ \ + "trn1 v1.4h, v10.4h, v11.4h \n" /*40 50 60 70 44 54 64 74*/ \ + "trn1 v3.4h, v14.4h, v15.4h \n" /*41 51 61 71 45 55 65 75*/ \ + \ + "trn2 v4.4h, v8.4h, v9.4h \n" /*02 10 20 30 06 14 24 34*/ \ + "trn2 v6.4h, v12.4h, v13.4h \n" /*03 11 21 31 07 15 25 35*/ \ + "trn2 v5.4h, v10.4h, v11.4h \n" /*42 50 60 70 46 54 64 74*/ \ + "trn2 v7.4h, v14.4h, v15.4h \n" /*43 51 61 71 47 55 65 75*/ \ + \ + "trn1 v8.2s, v0.2s, v1.2s \n" /*00 10 20 30 40 50 60 70*/ \ + "trn1 v9.2s, v2.2s, v3.2s \n" /*01 11 21 31 41 51 61 71*/ \ + "trn1 v10.2s, v4.2s, v5.2s \n" /*02 12 22 32 42 50 60 70*/ \ + "trn1 v11.2s, v6.2s, v7.2s \n" /*03 13 23 33 41 51 61 71*/ \ + \ + "trn2 v12.2s, v0.2s, v1.2s \n" /*04 14 24 34 44 54 64 74*/ \ + "trn2 v13.2s, v2.2s, v3.2s \n" /*05 15 25 35 45 55 65 75*/ \ + "trn2 v14.2s, v4.2s, v5.2s \n" /*06 16 22 32 42 50 60 70*/ \ + "trn2 v15.2s, v6.2s, v7.2s \n" /*07 17 23 33 41 51 61 71*/ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* 
din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "st1 {v8.8b}, [%[out0_ptr]], #8 \n" \ + "st1 {v9.8b}, [%[out1_ptr]], #8 \n" \ + "st1 {v10.8b}, [%[out2_ptr]], #8 \n" \ + "st1 {v11.8b}, [%[out3_ptr]], #8 \n" \ + \ + "st1 {v11.8b}, [%[out4_ptr]], #8 \n" \ + "st1 {v12.8b}, [%[out5_ptr]], #8 \n" \ + "st1 {v13.8b}, [%[out6_ptr]], #8 \n" \ + "st1 {v14.8b}, [%[out7_ptr]], #8 \n" \ + "bne 1b \n" + +#else +#define TRANS_C4 \ + "1: \n" \ + "vld1.32 {d0-d1}, [%[din0_ptr]] \n" \ + "vld1.32 {d2-d3}, [%[din1_ptr]] \n" \ + "vld1.32 {d4-d5}, [%[din2_ptr]] \n" \ + "vld1.32 {d6-d7}, [%[din3_ptr]] \n" \ + \ + "vtrn.32 q0, q1 \n" /*00 10 02 12 01 11 03 13*/ \ + "vtrn.32 q2, q3 \n" /*20 30 22 32 21 31 23 33 */ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride] \n" /* din+=c*size*/ \ + "vswp d1, d4 \n" \ + "vswp d3, d6 \n" \ + \ + "subs %[cnt], %[cnt], #1 \n" \ + "vst1.32 {d0-d1}, [%[out0_ptr]]! \n" \ + "vst1.32 {d2-d3}, [%[out1_ptr]]! \n" \ + "vst1.32 {d4-d5}, [%[out2_ptr]]! \n" \ + "vst1.32 {d6-d7}, [%[out3_ptr]]! \n" \ + "bne 1b \n" + +#define TRANS_C8 \ + "1: \n" \ + "vld1.8 d0, [%[din0_ptr]] \n" \ + "vld1.8 d1, [%[din1_ptr]] \n" \ + "vld1.8 d2, [%[din2_ptr]] \n" \ + "vld1.8 d3, [%[din3_ptr]] \n" \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "vtrn.8 d0, d1 \n" /*00 10 02 12 04 14 06 16*/ \ + "vtrn.8 d2, d3 \n" /*20 30 22 32 24 34 26 36 */ \ + \ + "vld1.8 d4, [%[din0_ptr]] \n" \ + "vld1.8 d5, [%[din1_ptr]] \n" \ + "vld1.8 d6, [%[din2_ptr]] \n" \ + "vld1.8 d7, [%[din3_ptr]] \n" \ + \ + "vtrn.16 d0, d2 \n" /*00 10 20 30 04 14 24 34*/ \ + "vtrn.16 d1, d3 \n" /* 01 11 21 31 05 15 25 35 */ \ + "vtrn.8 d4, d5 \n" /*40 50 02 12 04 14 06 16*/ \ + "vtrn.8 d6, d7 \n" /*60 70 22 32 24 34 26 36 */ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "vtrn.16 d4, d6 \n" /*40 50 60 70 04 14 24 34*/ \ + "vtrn.16 d5, d7 \n" /* 41 51 61 71 05 15 25 35 */ \ + \ + "vtrn.32 d0, d4 \n" /*00 10 20 30 40 50 60 70*/ \ + "vtrn.32 d1, d5 \n" /* 01 11 21 31 41 51 61 71 */ \ + "vtrn.32 d2, d6 \n" /*02 12 22 32 42 52 62 72*/ \ + "vtrn.32 d3, d7 \n" /* 03 11 21 33 43 53 63 73 */ \ + \ + "subs %[cnt], %[cnt], #1 \n" \ + "vst1.8 {d0}, [%[out0_ptr]]! \n" \ + "vst1.8 {d1}, [%[out1_ptr]]! \n" \ + "vst1.8 {d2}, [%[out2_ptr]]! \n" \ + "vst1.8 {d3}, [%[out3_ptr]]! \n" \ + "vst1.8 {d4}, [%[out4_ptr]]! \n" \ + "vst1.8 {d5}, [%[out5_ptr]]! \n" \ + "vst1.8 {d6}, [%[out6_ptr]]! \n" \ + "vst1.8 {d7}, [%[out7_ptr]]! 
\n" \ + "bne 1b \n" + +#endif +template <> +void NCHW2NHWC(int N, int C, int size, const float* X, float* Y) { + int cnt = C >> 2; + int remain = C % 4; + int sum = C * size; + int stride = size << 4; // 4 * size + int stride_w = stride >> 2; + for (int n = 0; n < N; n++) { + const float* din = X + n * sum; + float* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < size - 3; s += 4) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + size; + const float* din2_ptr = din1_ptr + size; + const float* din3_ptr = din2_ptr + size; + float* out0_ptr = dout + s * C; + float* out1_ptr = out0_ptr + C; + float* out2_ptr = out1_ptr + C; + float* out3_ptr = out2_ptr + C; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12"); +#else + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + for (int i = 0; i < remain; i++) { + const float* ptr = din0_ptr; + *out0_ptr++ = *ptr++; + *out1_ptr++ = *ptr++; + *out2_ptr++ = *ptr++; + *out3_ptr++ = *ptr++; + din0_ptr += size; + } + } + // remain size + for (; s < size; s++) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + size; + const float* din2_ptr = din1_ptr + size; + const float* din3_ptr = din2_ptr + size; + float* out0_ptr = dout + s * C; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + din0_ptr += stride_w; + din1_ptr += stride_w; + din2_ptr += stride_w; + din3_ptr += stride_w; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += size; + } + } + } +} +template <> +void NCHW2NHWC(int N, int C, int size, const int8_t* X, int8_t* Y) { + int cnt = C >> 3; + int remain = C % 8; + int sum = C * size; + int stride = size << 3; // 8 * size + int stride_w = size << 4; // 4 * size * 4 + for (int n = 0; n < N; n++) { + const int8_t* din = X + n * sum; + int8_t* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < size - 7; s += 8) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + size; + const int8_t* din2_ptr = din1_ptr + size; + const int8_t* din3_ptr = din2_ptr + size; + int8_t* out0_ptr = dout + s * C; + int8_t* out1_ptr = out0_ptr + C; + int8_t* out2_ptr = out1_ptr + C; + int8_t* out3_ptr = out2_ptr + C; + int8_t* out4_ptr = out3_ptr + C; + int8_t* out5_ptr = out4_ptr + C; + int8_t* out6_ptr = out5_ptr + C; + int8_t* out7_ptr = out6_ptr + C; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] 
"+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + // const int8_t* din_ptr = din + 8 * cnt * size + s; // remain channel + for (int i = 0; i < remain; i++) { + const int8_t* ptr = din0_ptr; + *out0_ptr = *ptr++; + *out1_ptr = *ptr++; + *out2_ptr = *ptr++; + *out3_ptr = *ptr++; + din0_ptr += size; + *out4_ptr = *ptr++; + *out5_ptr = *ptr++; + *out6_ptr = *ptr++; + *out7_ptr = *ptr++; + } + } + // remain size + for (; s < size; s++) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + size; + const int8_t* din2_ptr = din1_ptr + size; + const int8_t* din3_ptr = din2_ptr + size; + const int8_t* din4_ptr = din3_ptr + size; + const int8_t* din5_ptr = din4_ptr + size; + const int8_t* din6_ptr = din5_ptr + size; + const int8_t* din7_ptr = din6_ptr + size; + int8_t* out0_ptr = dout + s * C; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + *out0_ptr++ = *din4_ptr; + *out0_ptr++ = *din5_ptr; + *out0_ptr++ = *din6_ptr; + *out0_ptr++ = *din7_ptr; + din0_ptr += stride; + din1_ptr += stride; + din2_ptr += stride; + din3_ptr += stride; + din4_ptr += stride; + din5_ptr += stride; + din6_ptr += stride; + din7_ptr += stride; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += size; + } + } + } +} +template <> +void NHWC2NCHW(int N, int C, int size, const float* X, float* Y) { + int cnt = size >> 2; + int remain = size % 4; + int sum = C * size; + int stride = C << 4; // 4 * size + int stride_w = C << 2; + for (int n = 0; n < N; n++) { + const float* din = X + n * sum; + float* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < C - 3; s += 4) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + C; + const float* din2_ptr = din1_ptr + C; + const float* din3_ptr = din2_ptr + C; + float* out0_ptr = dout + s * size; + float* out1_ptr = out0_ptr + size; + float* out2_ptr = out1_ptr + size; + float* out3_ptr = out2_ptr + size; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] 
"+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + for (int i = 0; i < remain; i++) { + const float* ptr = din0_ptr; + *out0_ptr++ = *ptr++; + *out1_ptr++ = *ptr++; + *out2_ptr++ = *ptr++; + *out3_ptr++ = *ptr++; + din0_ptr += C; + } + } + // remain size + for (; s < C; s++) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + C; + const float* din2_ptr = din1_ptr + C; + const float* din3_ptr = din2_ptr + C; + float* out0_ptr = dout + s * size; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + din0_ptr += stride_w; + din1_ptr += stride_w; + din2_ptr += stride_w; + din3_ptr += stride_w; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += C; + } + } + } +} +template <> +void NHWC2NCHW(int N, int C, int size, const int8_t* X, int8_t* Y) { + int cnt = size >> 3; + int remain = size % 8; + int sum = C * size; + int stride = C << 3; // 8 * size + int stride_w = C << 4; // 4 * size + for (int n = 0; n < N; n++) { + const int8_t* din = X + n * sum; + int8_t* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < C - 7; s += 8) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + C; + const int8_t* din2_ptr = din1_ptr + C; + const int8_t* din3_ptr = din2_ptr + C; + const int8_t* din4_ptr = din3_ptr + C; + const int8_t* din5_ptr = din4_ptr + C; + const int8_t* din6_ptr = din5_ptr + C; + const int8_t* din7_ptr = din6_ptr + C; + int8_t* out0_ptr = dout + s * size; + int8_t* out1_ptr = out0_ptr + size; + int8_t* out2_ptr = out1_ptr + size; + int8_t* out3_ptr = out2_ptr + size; + int8_t* out4_ptr = out3_ptr + size; + int8_t* out5_ptr = out4_ptr + size; + int8_t* out6_ptr = out5_ptr + size; + int8_t* out7_ptr = out6_ptr + size; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + for (int i = 0; i < remain; i++) { + const int8_t* ptr = din0_ptr; + *out0_ptr++ = *ptr++; + *out1_ptr++ = *ptr++; + *out2_ptr++ = *ptr++; + *out3_ptr++ = *ptr++; + *out4_ptr++ = *ptr++; + *out5_ptr++ = *ptr++; + *out6_ptr++ = *ptr++; + *out7_ptr++ = *ptr++; + din0_ptr += C; + } + } + // remain size + for (; s < C; s++) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + C; + const int8_t* din2_ptr = din1_ptr + C; + const int8_t* din3_ptr = din2_ptr + C; + const int8_t* din4_ptr = din3_ptr + C; + 
const int8_t* din5_ptr = din4_ptr + C; + const int8_t* din6_ptr = din5_ptr + C; + const int8_t* din7_ptr = din6_ptr + C; + int8_t* out0_ptr = dout + s * size; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + *out0_ptr++ = *din4_ptr; + *out0_ptr++ = *din5_ptr; + *out0_ptr++ = *din6_ptr; + *out0_ptr++ = *din7_ptr; + din0_ptr += stride; + din1_ptr += stride; + din2_ptr += stride; + din3_ptr += stride; + din4_ptr += stride; + din5_ptr += stride; + din6_ptr += stride; + din7_ptr += stride; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += C; + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/layout.h b/lite/backends/arm/math/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..ed0e2f8b78a280c513161a02bb3b3b479008145a --- /dev/null +++ b/lite/backends/arm/math/layout.h @@ -0,0 +1,30 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +template +void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y); + +template +void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/packed_sgemm.cc b/lite/backends/arm/math/packed_sgemm.cc index 0d6eed9904902aa9539caf95172b0e4109e11f7d..cb9c049d81aee73b65bacd27a64138779d1532cc 100644 --- a/lite/backends/arm/math/packed_sgemm.cc +++ b/lite/backends/arm/math/packed_sgemm.cc @@ -14,6 +14,7 @@ #include "lite/backends/arm/math/packed_sgemm.h" #include +#include "lite/backends/arm/math/conv_block_utils.h" namespace paddle { namespace lite { @@ -51,8 +52,40 @@ void sgemm_prepacked_8x12(bool is_transB, int ldc, const float *bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext *ctx); + +void pack_m4(float *out, + const float *in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax); + +void pack_trans_m4(float *out, + const float *in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax); +void sgemm_prepacked_4x4(bool is_transB, + int M, + int N, + int K, + const float *A_packed, + const float *B, + int ldb, + float beta, + float *C, + int ldc, + const float *bias, + bool has_bias, + const operators::ActivationParam act_param, + ARMContext *ctx); #else // for kA72 void prepackA_6x8(float *out, @@ -104,7 +137,7 @@ void sgemm_prepacked_6x8(bool is_transB, int ldc, const float *bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext *ctx); // for kA73, 4x8 void sgemm_prepacked_4x8(bool is_transB, @@ -119,7 +152,7 @@ void sgemm_prepacked_4x8(bool is_transB, int ldc, const float *bias, bool has_bias, - bool has_relu, + const 
operators::ActivationParam act_param, ARMContext *ctx); #endif // __aarch64__ @@ -139,13 +172,21 @@ void prepackA(float *out, bool is_trans, ARMContext *ctx) { #ifdef __aarch64__ - if (is_trans) { - prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax); + if (mmax <= 4) { + if (is_trans) { + pack_trans_m4(out, in, alpha, ldin, m0, mmax, k0, kmax); + } else { + pack_m4(out, in, alpha, ldin, m0, mmax, k0, kmax); + } } else { - prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax); + if (is_trans) { + prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax); + } else { + prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax); + } } #else - if (ctx->arch() == kA73) { + if (ctx->arch() == kA73 || mmax <= 4) { if (is_trans) { prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax); } else { @@ -209,25 +250,42 @@ void sgemm_prepack(bool is_transB, int ldc, const float *bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext *ctx) { #ifdef __aarch64__ - sgemm_prepacked_8x12(is_transB, - M, - N, - K, - A_packed, - B, - ldb, - beta, - C, - ldc, - bias, - has_bias, - has_relu, - ctx); + if (M <= 4) { + sgemm_prepacked_4x4(is_transB, + M, + N, + K, + A_packed, + B, + ldb, + beta, + C, + ldc, + bias, + has_bias, + act_param, + ctx); + } else { + sgemm_prepacked_8x12(is_transB, + M, + N, + K, + A_packed, + B, + ldb, + beta, + C, + ldc, + bias, + has_bias, + act_param, + ctx); + } #else // armv7 - if (ctx->arch() == kA73) { + if (ctx->arch() == kA73 || M <= 4) { sgemm_prepacked_4x8(is_transB, M, N, @@ -240,7 +298,7 @@ void sgemm_prepack(bool is_transB, ldc, bias, has_bias, - has_relu, + act_param, ctx); } else { sgemm_prepacked_6x8(is_transB, @@ -255,7 +313,7 @@ void sgemm_prepack(bool is_transB, ldc, bias, has_bias, - has_relu, + act_param, ctx); } #endif // arm64 @@ -522,6 +580,147 @@ void prepackA_8x12(float *dout, } } } +void pack_m4(float *dout, + const float *inptr, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + int x_len = kmax - k0; + int stride = x_len * 4; + float zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(float) * x_len); + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + +#pragma omp parallel for + for (int y = m0; y < mmax; y += 4) { + float *outptr = dout + stride * (y - m0) / 4; + + const float *inptr0 = inptr + y * ldin + k0; + const float *inptr1 = inptr0 + ldin; + const float *inptr2 = inptr1 + ldin; + const float *inptr3 = inptr2 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(inptr0), + [ptr1] "r"(inptr1), + [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3) + : "memory"); + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 3) >= mmax) { + switch ((y + 3) - mmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + "cbz %w[has_alpha], 0f\n" /* check alpha == 1.f? 
*/ + "dup v31.4s, %w[alpha]\n" /* alpha to vector */ + "ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */ + "fmul v0.4s, v31.4s, v0.4s\n" /* mul alpha */ + "fmul v1.4s, v31.4s, v1.4s\n" /* mul alpha */ + "ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */ + "fmul v2.4s, v31.4s, v2.4s\n" /* mul alpha */ + "fmul v3.4s, v31.4s, v3.4s\n" /* mul alpha */ + "ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */ + "fmul v4.4s, v31.4s, v4.4s\n" /* mul alpha */ + "fmul v5.4s, v31.4s, v5.4s\n" /* mul alpha */ + "fmul v6.4s, v31.4s, v6.4s\n" /* mul alpha */ + "fmul v7.4s, v31.4s, v7.4s\n" /* mul alpha */ + "b 1f\n" /* to main process */ + "0: \n" /* alpha == 1 */ + "ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */ + "ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */ + "ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */ + "1: \n" /* main process */ + "trn1 v8.4s, v0.4s, v2.4s\n" /* a0b0a2b2*/ + "trn2 v9.4s, v0.4s, v2.4s\n" /* a1b1a3b3*/ + "trn1 v10.4s, v1.4s, v3.4s\n" /* a4b4a6b6*/ + "trn2 v11.4s, v1.4s, v3.4s\n" /* a5b5a7b7*/ + + "trn1 v12.4s, v4.4s, v6.4s\n" /* c0d0c2d2*/ + "trn2 v13.4s, v4.4s, v6.4s\n" /* c1d1c3d3*/ + "trn1 v14.4s, v5.4s, v7.4s\n" /* c4d4c6d6*/ + "trn2 v15.4s, v5.4s, v7.4s\n" /* c5d5c7d7*/ + + "trn1 v0.2d, v8.2d, v12.2d\n" /* a0b0c0d0 */ + "trn1 v1.2d, v9.2d, v13.2d\n" /* a1b1c1d1 */ + "trn1 v2.2d, v10.2d, v14.2d\n" /* a4b4c4d4 */ + "trn1 v3.2d, v11.2d, v15.2d\n" /* a5b5c5d5 */ + + "trn2 v4.2d, v8.2d, v12.2d\n" /* a2b2c2d2 */ + "trn2 v5.2d, v9.2d, v13.2d\n" /* a3b3c3d3 */ + "stp q0, q1, [%[outptr]], #32\n" /* save q0, q1, a0~h0*/ + "trn2 v6.2d, v10.2d, v14.2d\n" /* a6b6c6d6 */ + "trn2 v7.2d, v11.2d, v15.2d\n" /* a7b7c7d7 */ + "stp q4, q5, [%[outptr]], #32\n" /* save q2, q3, a1~h1*/ + "stp q2, q3, [%[outptr]], #32\n" /* save q4, q5, a2~h2*/ + "stp q6, q7, [%[outptr]], #32\n" /* save q6, q7, a3~h3*/ + + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : [alpha] "r"(alpha), [has_alpha] "r"(has_alpha) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "cc", + "memory"); + } + + for (; x > 0; x--) { + if (has_alpha) { + *outptr++ = *inptr0++ * alpha; + *outptr++ = *inptr1++ * alpha; + *outptr++ = *inptr2++ * alpha; + *outptr++ = *inptr3++ * alpha; + } else { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } + } +} void prepackA_trans_8x12(float *outptr, const float *in, @@ -682,6 +881,128 @@ void prepackA_trans_8x12(float *outptr, } } } +void pack_trans_m4(float *outptr, + const float *in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + auto inptr = in + k0 * ldin + m0; + uint32_t mask_buffer[4] = {0, 1, 2, 3}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 4 * (x_len / 4); + int stride_out = 4 * y_len; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const float *ptr0 = inptr + y * ldin; + const float *ptr1 = ptr0 + ldin; + const float *ptr2 = ptr1 + ldin; + const float *ptr3 = ptr2 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm 
pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + float *outptr_row_col = outptr + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + float32x4_t vr00 = vld1q_f32(ptr0); + float32x4_t vr10 = vld1q_f32(ptr1); + float32x4_t vr20 = vld1q_f32(ptr2); + float32x4_t vr30 = vld1q_f32(ptr3); + if (has_alpha) { + vr00 = vmulq_f32(vr00, valpha); + vr10 = vmulq_f32(vr10, valpha); + vr20 = vmulq_f32(vr20, valpha); + vr30 = vmulq_f32(vr30, valpha); + } + + vst1q_f32(outptr_row_col, vr00); + vst1q_f32(outptr_row_col + 4, vr10); + vst1q_f32(outptr_row_col + 8, vr20); + vst1q_f32(outptr_row_col + 12, vr30); + + ptr0 += 4; + ptr1 += 4; + ptr2 += 4; + ptr3 += 4; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + float32x4_t vr00 = vld1q_f32(ptr0); + float32x4_t vr10 = vld1q_f32(ptr1); + float32x4_t vr20 = vld1q_f32(ptr2); + float32x4_t vr30 = vld1q_f32(ptr3); + + if (has_alpha) { + vr00 = vmulq_f32(vr00, valpha); + vr10 = vmulq_f32(vr10, valpha); + vr20 = vmulq_f32(vr20, valpha); + vr30 = vmulq_f32(vr30, valpha); + } + + float32x4_t vr00_1 = vbslq_f32(vmask1, vr00, vzero); + float32x4_t vr10_1 = vbslq_f32(vmask1, vr10, vzero); + float32x4_t vr20_1 = vbslq_f32(vmask1, vr20, vzero); + float32x4_t vr30_1 = vbslq_f32(vmask1, vr30, vzero); + + vst1q_f32(outptr_row_col, vr00_1); + vst1q_f32(outptr_row_col + 4, vr10_1); + vst1q_f32(outptr_row_col + 8, vr20_1); + vst1q_f32(outptr_row_col + 12, vr30_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const float *ptr0 = inptr + y * ldin; + float *outptr_row_col = outptr + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + float32x4_t vr0 = vld1q_f32(ptr0); + if (has_alpha) { + vr0 = vmulq_f32(vr0, valpha); + } + vst1q_f32(outptr_row_col, vr0); + + ptr0 += 4; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + float32x4_t vr0 = vld1q_f32(ptr0); + + if (has_alpha) { + vr0 = vmulq_f32(vr0, valpha); + } + + float32x4_t vr0_1 = vbslq_f32(vmask1, vr0, vzero); + + vst1q_f32(outptr_row_col, vr0_1); + } + } +} #else // __aarch64__ void prepackA_6x8(float* outptr, @@ -1963,7 +2284,7 @@ void sgemm_prepacked_8x12(bool is_transB, int ldc, const float *bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext *ctx) { size_t l2_cache = ctx->llc_size() > 0 ? 
ctx->llc_size() : 512 * 1024; auto workspace = ctx->workspace_data(); @@ -2517,33 +2838,6 @@ void sgemm_prepacked_8x12(bool is_transB, "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ "11: \n" /* check if relu */ - "cbz %w[relu], 12f\n" /* skip relu */ - "movi v2.4s, #0\n" /* for relu*/ - "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ - "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ - "fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ - "fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ - "fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ - "fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ - "fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ - "fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ - "fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ - "fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ - "fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ - "fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ - "fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ - "fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ - "fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ - "fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ - "fmax v24.4s,v24.4s,v2.4s\n" /* relu*/ - "fmax v25.4s,v25.4s,v2.4s\n" /* relu*/ - "fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ - "fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ - "fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ - "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ - "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ - "fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ - "12: \n" "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ @@ -2566,7 +2860,6 @@ void sgemm_prepacked_8x12(bool is_transB, [c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7) : [bias_ptr] "r"(bias_local), - [relu] "r"(has_relu), [has_beta] "r"(has_beta), [beta] "r"(beta) : "cc","memory", @@ -2591,6 +2884,298 @@ void sgemm_prepacked_8x12(bool is_transB, } } } + if (act_param.has_active) { +#pragma omp parallel for num_threads(threads) + for (unsigned int x = 0; x < M; x++) { + float *dst = C + x * ldc; + act_switch_process(dst, dst, N, &act_param); + } + } +} + +void sgemm_prepacked_4x4(bool is_transB, + int M, + int N, + int K, + const float *A_packed, + const float *B, + int ldb, + float beta, + float *C, + int ldc, + const float *bias, + bool has_bias, + const operators::ActivationParam act_param, + ARMContext *ctx) { + size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; + auto workspace = ctx->workspace_data(); + int threads = ctx->threads(); + + const int n_block = 4; + const int m_block = 4; + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = (l2_cache - (m_block * K)) / (sizeof(float) * (K + m_block)); + x_block /= n_block; + x_block *= n_block; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + n_block - 1) / n_block; + x_block *= n_block; + x_block = x_block < n_block ? n_block : x_block; + + // unroll 2 loop + int tail_pre = (K & (KBLOCK - 1)); + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + int has_beta = fabsf(beta) > 1e-8f ? 1 : 0; + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + n_block - 1) / n_block; + remain = xmax - x0 - (bblocks - 1) * n_block; + if (remain > 0) { + flag_p_remain = true; + } + //! 
load bpanel + float *b_pannel = workspace; + if (is_transB) { + pack_m4(b_pannel, B, 1.0f, ldb, x0, xmax, 0, K); + } else { + pack_trans_m4(b_pannel, B, 1.0f, ldb, x0, xmax, 0, K); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += m_block) { + unsigned int ymax = y + m_block; + if (ymax > M) { + ymax = M; + } + + float bias_local[4] = {0}; + if (has_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + } + + float cout0[n_block]; // NOLINT + float cout1[n_block]; // NOLINT + float cout2[n_block]; // NOLINT + float cout3[n_block]; // NOLINT + + float *c_ptr0 = C + y * ldc + x0; + float *c_ptr1 = c_ptr0 + ldc; + float *c_ptr2 = c_ptr1 + ldc; + float *c_ptr3 = c_ptr2 + ldc; + + float *pout0 = c_ptr0; + float *pout1 = c_ptr1; + float *pout2 = c_ptr2; + float *pout3 = c_ptr3; + + const float *a_ptr_l = A_packed + y * K; + const float *b_ptr_l = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + case 2: + c_ptr1 = cout1; + case 1: + c_ptr2 = cout2; + case 0: + c_ptr3 = cout3; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + if (has_beta) { + for (int i = 0; i < remain; ++i) { + cout0[i] = pout0[i]; + cout1[i] = pout1[i]; + cout2[i] = pout2[i]; + cout3[i] = pout3[i]; + } + } + } + const float *a_ptr = a_ptr_l; + const float *b_ptr = b_ptr_l + xb * K * 4; + int tail = tail_pre; + int k = k_pre; + // clang-format off + asm volatile( + "prfm pldl1keep, [%[a_ptr]]\n" /* preload a*/ + "ld1 {v2.4s}, [%[bias_ptr]]\n" /* load bias to q2 */ + "dup v8.4s, v2.s[0]\n" /* out0 = bias[0] */ + "prfm pldl1keep, [%[b_ptr]]\n" /* preload b*/ + "dup v9.4s, v2.s[1]\n" /* out1 = bias[1] */ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ + "dup v10.4s, v2.s[2]\n" /* out2 = bias[2] */ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ + "dup v11.4s, v2.s[3]\n" /* out3 = bias[3] */ + "cbz %w[has_beta], 0f\n" /* skip if beta == 0
*/ + /* process beta */ + "dup v7.4s, %w[beta]\n" /* beta to vector */ + "ld1 {v0.4s}, [%[c_ptr0]]\n" /* load output r0 */ + "ld1 {v1.4s}, [%[c_ptr1]]\n" /* load output r1 */ + "fmla v8.4s, v0.4s, v7.4s\n" /* cr00 += beta * c_r00*/ + "fmla v9.4s, v1.4s, v7.4s\n" /* cr10 += beta * c_r10*/ + "ld1 {v2.4s}, [%[c_ptr2]]\n" + "ld1 {v3.4s}, [%[c_ptr3]]\n" + "fmla v10.4s, v2.4s, v7.4s\n" /* cr20 += beta * c_r20*/ + "fmla v11.4s, v3.4s, v7.4s\n" /* cr30 += beta * c_r30*/ + + "0: \n" /* check loop count */ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a10 to q0, q1*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "cbz %w[k], 2f\n" /* check loop count > 0 */ + /* main loop */ + /* unroll 0 */ + "1:\n" /* main loop */ + "fmla v8.4s, v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */ + "fmla v9.4s, v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 */ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b3 to q6, q7 */ + "fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */ + "fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */ + + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a20, a30 to q2, q3 */ + "fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */ + "fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5 */ + "fmla v10.4s, v5.4s, v1.s[2]\n" /* out2 = b1 * a10[2], b1 =q5 */ + "fmla v11.4s, v5.4s, v1.s[3]\n" /* out3 = b1 * a10[3], b1 =q5 */ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + + "fmla v8.4s, v6.4s, v2.s[0]\n" /* out0 = b2 * a20[0], b2 =q6 */ + "fmla v9.4s, v6.4s, v2.s[1]\n" /* out1 = b2 * a20[1], b2 =q6 */ + "fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b2 =q6*/ + "fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b2 =q6*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a10 to q0, q1 */ + + "fmla v8.4s, v7.4s, v3.s[0]\n" /* out0 = b3 * a30[0], b3 =q7*/ + "fmla v9.4s, v7.4s, v3.s[1]\n" /* out1 = b3 * a30[1], b3 =q7*/ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ + "fmla v10.4s, v7.4s, v3.s[2]\n" /* out2 = b3 * a30[2], b3 =q7*/ + "fmla v11.4s, v7.4s, v3.s[3]\n" /* out3 = b3 * a30[3], b3 =q7*/ + + "bne 1b\n" + "2:\n" /* process tail*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "beq 3f\n" /*jump to tail = 1*/ + /* final unroll 0 */ + /* unroll 0, tail > 1 */ + "fmla v8.4s, v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */ + "fmla v9.4s, v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 */ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */ + "fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */ + + "beq 4f\n" /*jump to tail = 2*/ + /* unroll 1, tail > 2 */ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b3 to q6, q7 */ + + "fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */ + "fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v10.4s, v5.4s, v1.s[2]\n" /* out2 = b1 * a10[2], b1 =q5 */ + "fmla v11.4s, v5.4s, v1.s[3]\n" /* out3 = b1 * a10[3], b1 =q5 */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a20, a30 to q2, q3 */ + + "beq 5f\n" /*jump to tail = 3*/ + /* unroll 2, tail = 4 */ + "fmla v8.4s, v6.4s, v2.s[0]\n" /* out0 = b2 * a20[0], b2 =q6 */ + "fmla v9.4s, v6.4s, v2.s[1]\n" /* out1 = b2 * a20[1], b2 =q6 */ + "fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b2 =q6*/ + "fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b2 =q6*/ + + /* unroll 3, tail = 4 */ + + "fmla v8.4s, v7.4s, v3.s[0]\n" /* out0 = b3 * a30[0], b3 =q7*/ + "fmla v9.4s, v7.4s, v3.s[1]\n" /* out1 = b3 * a30[1],
b3 =q7*/ + "fmla v10.4s, v7.4s, v3.s[2]\n" /* out2 = b3 * a30[2], b3 =q7*/ + "fmla v11.4s, v7.4s, v3.s[3]\n" /* out3 = b3 * a30[3], b3 =q7*/ + + "b 11f\n" + /* tails==1 final tail*/ + "3: \n" /* tail=1*/ + "fmla v8.4s, v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */ + "fmla v9.4s, v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 */ + "fmla v10.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */ + "fmla v11.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */ + + "b 11f\n" + /* tails==2 final tail*/ + "4:\n" /* tail = 2*/ + + "fmla v8.4s, v5.4s, v1.s[0]\n" /* out0 = b1 * a10[0], b1 =q5 */ + "fmla v9.4s, v5.4s, v1.s[1]\n" /* out1 = b1 * a10[1], b1 =q5*/ + "fmla v10.4s, v5.4s, v1.s[2]\n" /* out2 = b1 * a10[2], b1 =q5 */ + "fmla v11.4s, v5.4s, v1.s[3]\n" /* out3 = b1 * a10[3], b1 =q5 */ + + "b 11f\n" + /* tails==3 final tail*/ + "5:\n" /* tail = 3*/ + "fmla v8.4s, v6.4s, v2.s[0]\n" /* out0 = b2 * a20[0], b1 =q6 */ + "fmla v9.4s, v6.4s, v2.s[1]\n" /* out1 = b2 * a20[1], b1 =q6 */ + "fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b1 =q6*/ + "fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b1 =q6*/ + + "11: \n" /* check if relu */ + "st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */ + "st1 {v9.4s}, [%[c_ptr1]], #16\n" /* store r1 */ + "st1 {v10.4s}, [%[c_ptr2]], #16\n" /* store r2 */ + "st1 {v11.4s}, [%[c_ptr3]], #16\n" /* store r3 */ + + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [k] "+r"(k), + [tail] "+r"(tail), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3) + : [bias_ptr] "r"(bias_local), + [has_beta] "r"(has_beta), + [beta] "r"(beta) + : "cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11"); + // clang-format on + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + } + } + } + } + } + if (act_param.has_active) { +#pragma omp parallel for num_threads(threads) + for (unsigned int x = 0; x < M; x++) { + float *dst = C + x * ldc; + act_switch_process(dst, dst, N, &act_param); + } + } } #else // __aarch64__ /** @@ -2616,7 +3201,7 @@ void sgemm_prepacked_6x8(bool is_transB, int ldc, const float* bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext* ctx) { size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; auto* workspace = ctx->workspace_data(); @@ -2995,22 +3580,6 @@ void sgemm_prepacked_6x8(bool is_transB, "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" "2: @ check relu\n" - "cmp %[relu], #0 @ check if has relu\n" - "ble 6f @ skip relu if relu <= 0\n" - "vmov.u32 q0, #0 @ for relu\n" - "vmax.f32 q4, q4, q0 @ for relu\n" - "vmax.f32 q5, q5, q0 @ for relu\n" - "vmax.f32 q6, q6, q0 @ for relu\n" - "vmax.f32 q7, q7, q0 @ for relu\n" - "vmax.f32 q8, q8, q0 @ for relu\n" - "vmax.f32 q9, q9, q0 @ for relu\n" - "vmax.f32 q10, q10, q0 @ for relu\n" - "vmax.f32 q11, q11, q0 @ for relu\n" - "vmax.f32 q12, q12, q0 @ for relu\n" - "vmax.f32 q13, q13, q0 @ for relu\n" - "vmax.f32 q14, q14, q0 @ for relu\n" - "vmax.f32 q15, q15, q0 @ for relu\n" - "6: @ store result\n" "vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n" "vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n" "vst1.32 {d16-d19}, [%[c_ptr2]]! 
@ store r2\n" @@ -3028,7 +3597,6 @@ void sgemm_prepacked_6x8(bool is_transB, [k] "+r"(k), [tails] "+r"(tails) : [bias_ptr] "r"(bias_local), - [relu] "r"(has_relu), [beta] "r"(beta) : "q0","q1","q2","q3","q4", "q5","q6","q7","q8","q9","q10","q11", @@ -3048,6 +3616,13 @@ void sgemm_prepacked_6x8(bool is_transB, } } } + if (act_param.has_active) { +#pragma omp parallel for num_threads(threads) + for (unsigned int x = 0; x < M; x++) { + float* dst = C + x * ldc; + act_switch_process(dst, dst, N, &act_param); + } + } } void sgemm_prepacked_4x8(bool is_transB, @@ -3062,7 +3637,7 @@ void sgemm_prepacked_4x8(bool is_transB, int ldc, const float* bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext* ctx) { size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; auto* workspace = ctx->workspace_data(); @@ -3347,18 +3922,6 @@ void sgemm_prepacked_4x8(bool is_transB, /*aptr - 16*/ "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" "2: @ check relu\n" - "cmp %[relu], #0 @ check if has relu\n" - "ble 6f @ skip relu if relu <= 0\n" - "vmov.u32 q0, #0 @ for relu\n" - "vmax.f32 q8, q8, q0 @ for relu\n" - "vmax.f32 q9, q9, q0 @ for relu\n" - "vmax.f32 q10, q10, q0 @ for relu\n" - "vmax.f32 q11, q11, q0 @ for relu\n" - "vmax.f32 q12, q12, q0 @ for relu\n" - "vmax.f32 q13, q13, q0 @ for relu\n" - "vmax.f32 q14, q14, q0 @ for relu\n" - "vmax.f32 q15, q15, q0 @ for relu\n" - "6: @ store result\n" "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" "vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n" "vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n" @@ -3372,7 +3935,6 @@ void sgemm_prepacked_4x8(bool is_transB, [k] "+r"(k), [tails] "+r"(tails) : [bias_ptr] "r"(bias_local), - [relu] "r"(has_relu), [beta] "r"(beta) : "q0","q1","q2","q3", "q4","q5","q6","q7","q8","q9","q10", @@ -3389,6 +3951,13 @@ void sgemm_prepacked_4x8(bool is_transB, } } } + if (act_param.has_active) { +#pragma omp parallel for num_threads(threads) + for (unsigned int x = 0; x < M; x++) { + float* dst = C + x * ldc; + act_switch_process(dst, dst, N, &act_param); + } + } } #endif // __aarch64__ diff --git a/lite/backends/arm/math/packed_sgemm.h b/lite/backends/arm/math/packed_sgemm.h index 6c14cdb2ef62558a53c765719107d68da678246b..bc23e8eab7b972fef77fda2360ae8f12c2e5d0e3 100644 --- a/lite/backends/arm/math/packed_sgemm.h +++ b/lite/backends/arm/math/packed_sgemm.h @@ -17,6 +17,7 @@ #include #include "lite/core/context.h" #include "lite/core/tensor.h" +#include "lite/operators/op_params.h" namespace paddle { namespace lite { @@ -74,7 +75,7 @@ void sgemm_prepack(bool is_transB, int ldc, const float* bias, bool has_bias, - bool has_relu, + const operators::ActivationParam act_param, ARMContext* ctx); } // namespace math diff --git a/lite/backends/arm/math/packed_sgemm_c4.cc b/lite/backends/arm/math/packed_sgemm_c4.cc new file mode 100644 index 0000000000000000000000000000000000000000..af4934e85756f03ec197520b2b5c130e27bdcad6 --- /dev/null +++ b/lite/backends/arm/math/packed_sgemm_c4.cc @@ -0,0 +1,1704 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/packed_sgemm_c4.h" +#include <arm_neon.h> + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void loadb_c4(float* out, + const float* in, + const int xstart, + const int xend, + const int k_round, + const int n) { + const int xlen = (xend - xstart + NBLOCK_C4 - 1) / NBLOCK_C4 * NBLOCK_C4; + int xloop = xlen / NBLOCK_C4; + const int flag_remain = n < xstart + xlen; + int remain = 0; + int remain4 = 0; + int remain1 = 0; + if (flag_remain) { + remain = (n - xstart) - (xloop - 1) * NBLOCK_C4; + remain4 = remain >> 2; + remain1 = remain & 3; + xloop -= 1; + } + const int ldo = NBLOCK_C4 * k_round; + const int kloop = k_round >> 2; + in += xstart * 4; + if (xloop > 0) { +#pragma omp parallel for + for (int i = 0; i < kloop; ++i) { + float* out_ptr = out + 4 * NBLOCK_C4 * i; + const float* in_ptr = in + i * 4 * n; + for (int j = 0; j < xloop; ++j) { + float* out_p = out_ptr + j * ldo; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + "st1 {v0.4s, v1.4s}, [%[out]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[in]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[out]], #32 \n" + "ld1 {v6.4s, v7.4s}, [%[in]], #32 \n" + "st1 {v4.4s, v5.4s}, [%[out]], #32 \n" + "st1 {v6.4s, v7.4s}, [%[out]], #32 \n" + : [in] "+r"(in_ptr), [out] "+r"(out_p) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! \n" + "vld1.32 {d4-d7}, [%[in]]! \n" + "vst1.32 {d0-d3}, [%[out]]! \n" + "vld1.32 {d8-d11}, [%[in]]! \n" + "vst1.32 {d4-d7}, [%[out]]! \n" + "vld1.32 {d12-d15}, [%[in]]! \n" + "vst1.32 {d8-d11}, [%[out]]! \n" + "vst1.32 {d12-d15}, [%[out]]! \n" + : [in] "+r"(in_ptr), [out] "+r"(out_p) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + } + } + float* out_remain4 = out + xloop * k_round * NBLOCK_C4; + const float* in_remain4 = in + xloop * NBLOCK_C4 * 4; + if (remain4) { +#pragma omp parallel for + for (int i = 0; i < kloop; ++i) { + float* out_ptr = out_remain4 + 4 * 4 * i; + const float* in_ptr = in_remain4 + i * 4 * n; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + "st1 {v0.4s, v1.4s}, [%[out]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[out]], #32 \n" + : [in] "+r"(in_ptr), [out] "+r"(out_ptr) + : + : "v0", "v1", "v2", "v3"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! \n" + "vld1.32 {d4-d7}, [%[in]]! \n" + "vst1.32 {d0-d3}, [%[out]]! \n" + "vst1.32 {d4-d7}, [%[out]]!
\n" + : [in] "+r"(in_ptr), [out] "+r"(out_ptr) + : + : "q0", "q1", "q2", "q3"); +#endif // __aarch64__ + } + } + float* out_remain1 = out_remain4 + remain4 * k_round * 4; + const float* in_remain1 = in_remain4 + remain4 * 4 * 4; + if (remain1) { +#pragma omp parallel for + for (int i = 0; i < kloop; ++i) { + float* out_ptr = out_remain1 + 4 * remain1 * i; + const float* in_ptr = in_remain1 + i * 4 * n; + for (int j = 0; j < remain1; ++j) { + float32x4_t vin = vld1q_f32(in_ptr); + in_ptr += 4; + vst1q_f32(out_ptr, vin); + out_ptr += 4; + } + } + } +} + +void sgemm_prepack_c4_common(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx) { + const int m_round = (M + 3) / 4 * 4; + const int k_round = (K + 3) / 4 * 4; + size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; + int threads = ctx->threads(); + auto workspace = ctx->workspace_data(); + // l2 = ablock * K * threads + K * bchunk_w + threads * ablock * bchunk_w; + int bchunk_w = (l2_cache - threads * k_round * sizeof(float)) / + ((k_round + threads * MBLOCK_C4) * sizeof(float)); + bchunk_w = bchunk_w > N ? N : bchunk_w; + bchunk_w = bchunk_w / NBLOCK_C4 * NBLOCK_C4; + bchunk_w = bchunk_w > NBLOCK_C4 ? bchunk_w : NBLOCK_C4; + int bchunk_loop = (N + bchunk_w - 1) / bchunk_w; + + const int h_loop = m_round >> 2; // MBLOCK_C4 == 4; + const int kcnt = (k_round + KBLOCK_C4 - 1) / KBLOCK_C4; + const int ldc = N * 4; + const int lda = k_round * 4; + float bias_buf[m_round]; // NOLINT + if (has_bias) { + memcpy(bias_buf, bias, M * sizeof(float)); + memset(bias_buf + M, 0, (m_round - M) * sizeof(float)); + } else { + memset(bias_buf, 0, m_round * sizeof(float)); + } + // bchunk_loop + float* c = C; + for (int n = 0; n < bchunk_loop; ++n) { + int x_start = n * bchunk_w; + int x_end = x_start + bchunk_w; + int w_loop = bchunk_w / NBLOCK_C4; + int flag_remain = 0; + int w_loop4 = 0; + int remain = 0; + if (x_end > N) { + w_loop = (N - x_start) / NBLOCK_C4; + int w_loop_rem = (N - x_start) - w_loop * NBLOCK_C4; + w_loop4 = w_loop_rem >> 2; + remain = w_loop_rem & 3; + x_end = N; + flag_remain = 1; + } + float* bchunk = workspace; + loadb_c4(bchunk, B, x_start, x_end, k_round, N); + float* cchunk = c + n * bchunk_w * 4; + int has_remain = (n == bchunk_loop - 1) && flag_remain; +#pragma omp parallel for num_threads(threads) + for (int h = 0; h < h_loop; ++h) { + float* bias_h = bias_buf + h * 4; +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vbias = vld1q_f32(bias_h); +#endif + const float* ablock = A_packed + h * lda; + const float* bblock = bchunk; + float* cblock = cchunk + h * ldc; + for (int w = 0; w < w_loop; ++w) { + int cnt = kcnt; + const float* ablock_ptr = ablock; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[a]] \n" + "prfm pldl1keep, [%[b]] \n" + "prfm pldl1keep, [%[b], #64] \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/ + "mov v12.16b, %[vbias].16b \n" /* mov bias to c3*/ + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "mov v13.16b, %[vbias].16b \n" /* mov bias to c4*/ + "mov v14.16b, %[vbias].16b \n" /* mov bias to c5*/ + "mov v15.16b, %[vbias].16b \n" /* mov bias to c6*/ + "mov v16.16b, %[vbias].16b \n" /* mov bias to c7*/ + "1:\n" + /* load b0b1b2b3 to v5-v8 */ + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v7.4s, v8.4s}, [%[b]], 
#32 \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v1.4s, v7.s[0] \n" + "fmla v12.4s, v1.4s, v8.s[0] \n" + /* load b4b5b6b7 to v25-v28 */ + "ld1 {v25.4s, v26.4s}, [%[b]], #32 \n" + "ld1 {v27.4s, v28.4s}, [%[b]], #32 \n" + "prfm pldl1keep, [%[a], #32] \n" + "fmla v9.4s, v2.4s, v5.s[1] \n" + "fmla v10.4s, v2.4s, v6.s[1] \n" + "fmla v11.4s, v2.4s, v7.s[1] \n" + "fmla v12.4s, v2.4s, v8.s[1] \n" + "prfm pldl1keep, [%[b], #64] \n" + "fmla v13.4s, v1.4s, v25.s[0] \n" + "fmla v14.4s, v1.4s, v26.s[0] \n" + "fmla v15.4s, v1.4s, v27.s[0] \n" + "fmla v16.4s, v1.4s, v28.s[0] \n" + /* load a2a3 to v3-v4 */ + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b], #128] \n" + "fmla v13.4s, v2.4s, v25.s[1] \n" + "fmla v14.4s, v2.4s, v26.s[1] \n" + "fmla v15.4s, v2.4s, v27.s[1] \n" + "fmla v16.4s, v2.4s, v28.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v3.4s, v7.s[2] \n" + "fmla v12.4s, v3.4s, v8.s[2] \n" + "fmla v13.4s, v3.4s, v25.s[2] \n" + "fmla v14.4s, v3.4s, v26.s[2] \n" + "fmla v15.4s, v3.4s, v27.s[2] \n" + "fmla v16.4s, v3.4s, v28.s[2] \n" + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "fmla v9.4s, v4.4s, v5.s[3] \n" + "fmla v10.4s, v4.4s, v6.s[3] \n" + "fmla v11.4s, v4.4s, v7.s[3] \n" + "fmla v12.4s, v4.4s, v8.s[3] \n" + + "fmla v13.4s, v4.4s, v25.s[3] \n" + "fmla v14.4s, v4.4s, v26.s[3] \n" + "fmla v15.4s, v4.4s, v27.s[3] \n" + "fmla v16.4s, v4.4s, v28.s[3] \n" + "bne 1b\n" + "cbz %w[relu], 2f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "fmax v12.4s, v12.4s, %[vzero].4s \n" + "fmax v13.4s, v13.4s, %[vzero].4s \n" + "fmax v14.4s, v14.4s, %[vzero].4s \n" + "fmax v15.4s, v15.4s, %[vzero].4s \n" + "fmax v16.4s, v16.4s, %[vzero].4s \n" + "2:\n" + "st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[c]], #64 \n" + "st1 {v13.4s, v14.4s, v15.4s, v16.4s}, [%[c]], #64 \n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), [relu] "r"(has_relu), + [vbias] "w"(vbias), [vzero] "w" (vzero) + : "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v25", "v26", "v27", "v28", "cc", "memory"); +#else + asm volatile( + "vld1.32 {d6-d7}, [%[bias]] \n" + "pld [%[a]] \n" + "pld [%[b]] \n" + "pld [%[b], #64] \n" + "vmov.32 q8, q3 \n" /* mov bias to c0*/ + "vmov.32 q9, q3 \n" /* mov bias to c1*/ + "vmov.32 q10, q3 \n" /* mov bias to c2*/ + "vmov.32 q11, q3 \n" /* mov bias to c3*/ + "vld1.32 {d0-d3}, [%[a]]! \n" + "vmov.32 q12, q3 \n" /* mov bias to c4*/ + "vmov.32 q13, q3 \n" /* mov bias to c5*/ + "vmov.32 q14, q3 \n" /* mov bias to c6*/ + "vmov.32 q15, q3 \n" /* mov bias to c7*/ + "1:\n" + /* c0c1c2c3 */ + "vld1.32 {d8-d11}, [%[b]]! \n" + "vld1.32 {d12-d15}, [%[b]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q0, d8[0] \n" + "vmla.f32 q9, q0, d10[0] \n" + "vmla.f32 q10, q0, d12[0] \n" + "vmla.f32 q11, q0, d14[0] \n" + "vld1.32 {d4-d7}, [%[a]]! \n" + "vmla.f32 q8, q1, d8[1] \n" + "vmla.f32 q9, q1, d10[1] \n" + "vmla.f32 q10, q1, d12[1] \n" + "vmla.f32 q11, q1, d14[1] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q2, d9[0] \n" + "vmla.f32 q9, q2, d11[0] \n" + "vmla.f32 q10, q2, d13[0] \n" + "vmla.f32 q11, q2, d15[0] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q3, d9[1] \n" + "vmla.f32 q9, q3, d11[1] \n" + "vld1.f32 {d8-d11}, [%[b]]! 
\n" + "vmla.f32 q10, q3, d13[1] \n" + "vmla.f32 q11, q3, d15[1] \n" + "vld1.32 {d12-d15}, [%[b]]! \n" + /* c4c5c6c7 */ + "vmla.f32 q12, q0, d8[0] \n" + "vmla.f32 q13, q0, d10[0] \n" + "vmla.f32 q14, q0, d12[0] \n" + "vmla.f32 q15, q0, d14[0] \n" + "pld [%[a], #32] \n" + "vmla.f32 q12, q1, d8[1] \n" + "vmla.f32 q13, q1, d10[1] \n" + "vmla.f32 q14, q1, d12[1] \n" + "vmla.f32 q15, q1, d14[1] \n" + "vld1.32 {d0-d3}, [%[a]]! \n" + "vmla.f32 q12, q2, d9[0] \n" + "vmla.f32 q13, q2, d11[0] \n" + "vmla.f32 q14, q2, d13[0] \n" + "vmla.f32 q15, q2, d15[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q12, q3, d9[1] \n" + "vmla.f32 q13, q3, d11[1] \n" + "vmla.f32 q14, q3, d13[1] \n" + "vmla.f32 q15, q3, d15[1] \n" + "bne 1b\n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q8, q8, q0 \n" + "vmax.f32 q9, q9, q0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "vmax.f32 q12, q12, q0 \n" + "vmax.f32 q13, q13, q0 \n" + "vmax.f32 q14, q14, q0 \n" + "vmax.f32 q15, q15, q0 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! \n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), + [relu] "r"(has_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory"); +#endif + // clang-format on + } + if (has_remain) { + if (w_loop4 > 0) { + int cnt = kcnt; + const float* ablock_ptr = ablock; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[a]] \n" + "prfm pldl1keep, [%[b]] \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/ + "mov v12.16b, %[vbias].16b \n" /* mov bias to c3*/ + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "1:\n" + /* load b0b1b2b3 to v5-v8 */ + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v7.4s, v8.4s}, [%[b]], #32 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v1.4s, v7.s[0] \n" + "fmla v12.4s, v1.4s, v8.s[0] \n" + /* load a2a3 to v3-v4 */ + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[a]] \n" + "fmla v9.4s, v2.4s, v5.s[1] \n" + "fmla v10.4s, v2.4s, v6.s[1] \n" + "fmla v11.4s, v2.4s, v7.s[1] \n" + "fmla v12.4s, v2.4s, v8.s[1] \n" + "prfm pldl1keep, [%[b]] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v3.4s, v7.s[2] \n" + "fmla v12.4s, v3.4s, v8.s[2] \n" + /* load a0a1 to v1-v2 */ + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "fmla v9.4s, v4.4s, v5.s[3] \n" + "fmla v10.4s, v4.4s, v6.s[3] \n" + "fmla v11.4s, v4.4s, v7.s[3] \n" + "fmla v12.4s, v4.4s, v8.s[3] \n" + "bne 1b\n" + "cbz %w[relu], 2f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "fmax v12.4s, v12.4s, %[vzero].4s \n" + "2:\n" + "st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [%[c]], #64 \n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), + [relu] "r"(has_relu), + [vbias] "w"(vbias), + [vzero] "w" (vzero) + : "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "cc", "memory"); +#else + asm volatile( + "pld [%[a]] \n" + "pld [%[b]] \n" + "vld1.32 {d6-d7}, [%[bias]] \n" + "vld1.32 {d0-d3}, [%[a]]! 
\n" /* load a0 a1 */ + "vmov.32 q8, q3 \n" /* mov bias to c0 */ + "vmov.32 q9, q3 \n" /* mov bias to c1 */ + "vmov.32 q10, q3 \n" /* mov bias to c2 */ + "vmov.32 q11, q3 \n" /* mov bias to c3 */ + "1:\n" + /* c0c1c2c3 */ + "vld1.32 {d8-d11}, [%[b]]! \n" + "vld1.32 {d12-d15}, [%[b]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q0, d8[0] \n" + "vmla.f32 q9, q0, d10[0] \n" + "vmla.f32 q10, q0, d12[0] \n" + "vmla.f32 q11, q0, d14[0] \n" + "vld1.32 {d4-d7}, [%[a]]! \n" + "pld [%[a]] \n" + "vmla.f32 q8, q1, d8[1] \n" + "vmla.f32 q9, q1, d10[1] \n" + "vmla.f32 q10, q1, d12[1] \n" + "vmla.f32 q11, q1, d14[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q2, d9[0] \n" + "vmla.f32 q9, q2, d11[0] \n" + "vmla.f32 q10, q2, d13[0] \n" + "vmla.f32 q11, q2, d15[0] \n" + "vld1.32 {d0-d3}, [%[a]]! \n" + "vmla.f32 q8, q3, d9[1] \n" + "vmla.f32 q9, q3, d11[1] \n" + "vmla.f32 q10, q3, d13[1] \n" + "vmla.f32 q11, q3, d15[1] \n" + "bne 1b\n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q8, q8, q0 \n" + "vmax.f32 q9, q9, q0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), [relu] "r"(has_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "cc", "memory"); +#endif + // clang-format on + } + if (remain > 0) { + int cnt = kcnt; + const float* ablock_ptr = ablock; +// clang-format off +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[a]] \n" + "prfm pldl1keep, [%[b]] \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "cmp %w[remain], #3 \n" + "beq 1f \n" + "cmp %w[remain], #2 \n" + "beq 2f \n" + /* remain 1 */ + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vzero].16b \n" /* mov zero to c1*/ + "3: \n" + "ld1 {v5.4s}, [%[b]], #16 \n" + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v2.4s, v5.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v4.4s, v5.s[3] \n" + "bne 3b \n" + "fadd v9.4s, v9.4s, v10.4s \n" + "cbz %w[relu], 6f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "6: \n" + "st1 {v9.4s}, [%[c]], #16 \n" + "b 9f \n" + /* remain 2 */ + "2: \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vzero].16b \n" /* mov zero to c2*/ + "mov v12.16b, %[vzero].16b \n" /* mov zero to c3*/ + "4: \n" + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + "fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v2.4s, v5.s[1] \n" + "fmla v12.4s, v2.4s, v6.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v4.4s, v5.s[3] \n" + "fmla v12.4s, v4.4s, v6.s[3] \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "bne 4b \n" + "fadd v9.4s, v9.4s, v11.4s \n" + "fadd v10.4s, v10.4s, v12.4s \n" + "cbz %w[relu], 7f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "7: \n" + "st1 {v9.4s, v10.4s}, [%[c]], #32 \n" + "b 9f \n" + /* remain 3 */ + "1: \n" + "mov v9.16b, %[vbias].16b \n" /* mov bias to c0*/ + "mov v10.16b, %[vbias].16b \n" /* mov bias to c1*/ + "mov v11.16b, %[vbias].16b \n" /* mov bias to c2*/ + "5: \n" + "ld1 {v5.4s, v6.4s}, [%[b]], #32 \n" + "ld1 {v7.4s}, [%[b]], #16 \n" + "fmla v9.4s, v1.4s, v5.s[0] \n" + 
"fmla v10.4s, v1.4s, v6.s[0] \n" + "fmla v11.4s, v1.4s, v7.s[0] \n" + "ld1 {v3.4s, v4.4s}, [%[a]], #32 \n" + "fmla v9.4s, v2.4s, v5.s[1] \n" + "fmla v10.4s, v2.4s, v6.s[1] \n" + "fmla v11.4s, v2.4s, v7.s[1] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "fmla v9.4s, v3.4s, v5.s[2] \n" + "fmla v10.4s, v3.4s, v6.s[2] \n" + "fmla v11.4s, v3.4s, v7.s[2] \n" + "prfm pldl1keep, [%[a]] \n" + "fmla v9.4s, v4.4s, v5.s[3] \n" + "fmla v10.4s, v4.4s, v6.s[3] \n" + "fmla v11.4s, v4.4s, v7.s[3] \n" + "ld1 {v1.4s, v2.4s}, [%[a]], #32 \n" + "bne 5b \n" + "cbz %w[relu], 8f \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "8: \n" + "st1 {v9.4s, v10.4s}, [%[c]], #32 \n" + "st1 {v11.4s}, [%[c]], #16 \n" + "9:\n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), [relu] "r"(has_relu), + [remain] "r"(remain), [vbias] "w"(vbias), + [vzero] "w" (vzero) + : "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v9", + "v10", "v11", "v12", "cc","memory"); +#else + asm volatile( + "pld [%[a]] \n" + "pld [%[b]] \n" + "vld1.32 {d0-d1}, [%[bias]] \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "vmov.u32 q15, #0 \n" + "cmp %[remain], #3 \n" + "beq 1f \n" + "cmp %[remain], #2 \n" + "beq 2f \n" + /* remain 1 */ + "vmov.32 q9, q0 \n" /* mov bias to c0*/ + "vmov.32 q10, q15 \n" /* mov zero to c1*/ + "3: \n" + "vld1.32 {d10-d11}, [%[b]]! \n" + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q9, q1, d10[0] \n" + "vmla.f32 q10, q2, d10[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "vmla.f32 q9, q3, d11[0] \n" + "vmla.f32 q10, q4, d11[1] \n" + "bne 3b \n" + "vadd.f32 q9, q9, q10 \n" + "cmp %[relu], #0 \n" + "beq 6f \n" + "vmax.f32 q9, q9, q15 \n" + "6: \n" + "vst1.32 {d18-d19}, [%[c]]! \n" + "b 9f \n" + /* remain 2 */ + "2: \n" + "vmov.u32 q9, q0 \n" /* mov bias to c0*/ + "vmov.u32 q10, q0 \n" /* mov bias to c1*/ + "vmov.u32 q11, q15 \n" /* mov zero to c2*/ + "vmov.u32 q12, q15 \n" /* mov zero to c3*/ + "4: \n" + "vld1.32 {d10-d13}, [%[b]]! \n" + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q9, q1, d10[0] \n" + "vmla.f32 q10, q1, d12[0] \n" + "vmla.f32 q11, q2, d10[1] \n" + "vmla.f32 q12, q2, d12[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q9, q3, d11[0] \n" + "vmla.f32 q10, q3, d13[0] \n" + "vmla.f32 q11, q4, d11[1] \n" + "vmla.f32 q12, q4, d13[1] \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "bne 4b \n" + "vadd.f32 q9, q9, q11 \n" + "vadd.f32 q10, q10, q12 \n" + "cmp %[relu], #0 \n" + "beq 7f \n" + "vmax.f32 q9, q9, q15 \n" + "vmax.f32 q10, q10, q15 \n" + "7: \n" + "vst1.32 {d18-d21}, [%[c]]! \n" + "b 9f \n" + /* remain 3 */ + "1: \n" + "vmov.u32 q9, q0 \n" /* mov bias to c0*/ + "vmov.u32 q10, q0 \n" /* mov bias to c1*/ + "vmov.u32 q11, q0 \n" /* mov bias to c2*/ + "5: \n" + "vld1.32 {d10-d13}, [%[b]]! \n" + "vld1.32 {d14-d15}, [%[b]]! \n" + "vmla.f32 q9, q1, d10[0] \n" + "vmla.f32 q10, q1, d12[0] \n" + "vmla.f32 q11, q1, d14[0] \n" + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q9, q2, d10[1] \n" + "vmla.f32 q10, q2, d12[1] \n" + "vmla.f32 q11, q2, d14[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q9, q3, d11[0] \n" + "vmla.f32 q10, q3, d13[0] \n" + "vmla.f32 q11, q3, d15[0] \n" + "pld [%[a]] \n" + "vmla.f32 q9, q4, d11[1] \n" + "vmla.f32 q10, q4, d13[1] \n" + "vmla.f32 q11, q4, d15[1] \n" + "vld1.32 {d2-d5}, [%[a]]! \n" + "bne 5b \n" + "cmp %[relu], #0 \n" + "beq 8f \n" + "vmax.f32 q9, q9, q15 \n" + "vmax.f32 q10, q10, q15 \n" + "vmax.f32 q11, q11, q15 \n" + "8: \n" + "vst1.32 {d18-d21}, [%[c]]! 
\n" + "vst1.32 {d22-d23}, [%[c]]! \n" + "9:\n" + : [a] "+r"(ablock_ptr), + [b] "+r"(bblock), + [c] "+r"(cblock), + [cnt] "+r"(cnt) + : [bias] "r"(bias_h), + [relu] "r"(has_relu), + [remain] "r"(remain) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q9", + "q10", "q11", "q12", "q15", "cc","memory"); +#endif + // clang-format on + } + } + } + } +} +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx) { + const int m_round = (M + 3) / 4 * 4; + const int k_round = (K + 3) / 4 * 4; + const int mloop = m_round >> 2; + const int lda = 4 * k_round; + const int ldb_byte = 4 * N * sizeof(float); + const int kcnt = k_round >> 2; + float bias_buf[m_round]; // NOLINT + if (has_bias) { + memcpy(bias_buf, bias, M * sizeof(float)); + memset(bias_buf + M, 0, (m_round - M) * sizeof(float)); + } else { + memset(bias_buf, 0, m_round * sizeof(float)); + } +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); +#endif + const float* bias_ptr = bias_buf; + for (int m = 0; m < mloop; ++m) { +#ifdef __aarch64__ + float32x4_t vbias = vld1q_f32(bias_ptr); +#endif + const float* b = B; + int n = N; +#ifdef __aarch64__ + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* mov bias to c0-c7*/ + "mov v8.16b, %[vbias].16b \n" + "mov v9.16b, %[vbias].16b \n" + "mov v10.16b, %[vbias].16b \n" + "mov v11.16b, %[vbias].16b \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "mov v12.16b, %[vbias].16b \n" + "mov v13.16b, %[vbias].16b \n" + "mov v14.16b, %[vbias].16b \n" + "mov v15.16b, %[vbias].16b \n" + "1:\n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmla v12.4s, v16.4s, v4.s[0] \n" + "fmla v13.4s, v16.4s, v5.s[0] \n" + "fmla v14.4s, v16.4s, v6.s[0] \n" + "fmla v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "cbz %w[relu], 2f \n" + "fmax v8.4s, v8.4s, %[vzero].4s \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax 
v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "fmax v12.4s, v12.4s, %[vzero].4s \n" + "fmax v13.4s, v13.4s, %[vzero].4s \n" + "fmax v14.4s, v14.4s, %[vzero].4s \n" + "fmax v15.4s, v15.4s, %[vzero].4s \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [vbias] "w" (vbias), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* mov bias to c0-c3*/ + "mov v8.16b, %[vbias].16b \n" + "mov v9.16b, %[vbias].16b \n" + "mov v10.16b, %[vbias].16b \n" + "mov v11.16b, %[vbias].16b \n" + "1:\n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "cbz %w[relu], 2f \n" + "fmax v8.4s, v8.4s, %[vzero].4s \n" + "fmax v9.4s, v9.4s, %[vzero].4s \n" + "fmax v10.4s, v10.4s, %[vzero].4s \n" + "fmax v11.4s, v11.4s, %[vzero].4s \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [vbias] "w" (vbias), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v8", "v9", + "v10", "v11", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* mov bias to c0 */ + "mov v8.16b, %[vbias].16b \n" + "mov v9.16b, %[vzero].16b \n" + "1:\n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v17.4s, v0.s[1] \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "bne 1b \n" + "fadd v8.4s, v8.4s, v9.4s \n" + "cbz %w[relu], 2f \n" + "fmax v8.4s, v8.4s, %[vzero].4s \n" + "2:\n" + "st1 {v8.4s}, [%[c]], #16 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [vbias] "w" (vbias), + [vzero] "w" (vzero) + : "v0", "v8", "v9", "v16", "v17", + "v18", "v19", "cc", "memory" + ); + b += 4; 
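+ // each single-column iteration consumes one c4-packed column of B (4 floats), hence b += 4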
+ } +#else + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + "vld1.32 {d6-d7}, [%[bias]] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + /* mov bias to c0-c7*/ + "vmov.u32 q8, q3 \n" + "vmov.u32 q9, q3 \n" + "vmov.u32 q10, q3 \n" + "vmov.u32 q11, q3 \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmov.u32 q12, q3 \n" + "vmov.u32 q13, q3 \n" + "vmov.u32 q14, q3 \n" + "vmov.u32 q15, q3 \n" + "1:\n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q12, q4, d0[0] \n" + "vmla.f32 q13, q4, d2[0] \n" + "vmla.f32 q14, q4, d4[0] \n" + "vmla.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "bne 1b \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmov.u32 q0, #0 \n" + "vmax.f32 q8, q8, q0 \n" + "vmax.f32 q9, q9, q0 \n" + "vmax.f32 q10, q10, q0 \n" + "vmax.f32 q11, q11, q0 \n" + "vmax.f32 q12, q12, q0 \n" + "vmax.f32 q13, q13, q0 \n" + "vmax.f32 q14, q14, q0 \n" + "vmax.f32 q15, q15, q0 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [bias] "r" (bias_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "vld1.32 {d24-d25}, [%[bias]] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + /* mov bias to c0-c3*/ + "vmov.u32 q8, q12 \n" + "vmov.u32 q9, q12 \n" + "vmov.u32 q10, q12 \n" + "vmov.u32 q11, q12 \n" + "vmov.u32 q13, #0 \n" + "1:\n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! 
\n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 1b \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmax.f32 q8, q8, q13 \n" + "vmax.f32 q9, q9, q13 \n" + "vmax.f32 q10, q10, q13 \n" + "vmax.f32 q11, q11, q13 \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]!\n" + "vst1.32 {d20-d23}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [bias] "r" (bias_ptr) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "vld1.32 {d14-d15}, [%[bias]] \n" + "vmov.u32 q8, #0 \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + /* mov bias to c0 */ + "vmov.u32 q5, q7 \n" + "vmov.u32 q6, q8 \n" + "1:\n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "vmla.f32 q5, q1, d0[0] \n" + "vmla.f32 q6, q2, d0[1] \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! 
\n" + "bne 1b \n" + "vadd.f32 q5, q5, q6 \n" + "cmp %[relu], #0 \n" + "beq 2f \n" + "vmax.f32 q5, q5, q8 \n" + "2:\n" + "vst1.32 {d10-d11}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [relu] "r" (has_relu), + [ldb] "r" (ldb_byte), + [bias] "r" (bias_ptr) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "cc", "memory" + ); + // clang-format on + b += 4; + } +#endif + bias_ptr += 4; + A_packed += lda; + } +} + +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + ARMContext* ctx) { + const int m_round = (M + 3) / 4 * 4; + const int k_round = (K + 3) / 4 * 4; + const int mloop = m_round >> 2; + const int lda = 4 * k_round; + const int ldb_byte = 4 * N * sizeof(float); + const int kcnt = k_round >> 2; +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); +#endif + for (int m = 0; m < mloop; ++m) { + const float* b = B; + int n = N; +#ifdef __aarch64__ + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + /* load a2, a3 */ + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v16.4s, v1.s[0] \n" + "fmul v10.4s, v16.4s, v2.s[0] \n" + "fmul v11.4s, v16.4s, v3.s[0] \n" + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmul v12.4s, v16.4s, v4.s[0] \n" + "fmul v13.4s, v16.4s, v5.s[0] \n" + "fmul v14.4s, v16.4s, v6.s[0] \n" + "fmul v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + 
"ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmla v12.4s, v16.4s, v4.s[0] \n" + "fmla v13.4s, v16.4s, v5.s[0] \n" + "fmla v14.4s, v16.4s, v6.s[0] \n" + "fmla v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v16.4s, v1.s[0] \n" + "fmul v10.4s, v16.4s, v2.s[0] \n" + "fmul v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + 
[vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v8", "v9", + "v10", "v11", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v17.4s, v0.s[1] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "beq 2f \n" + "1:\n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v17.4s, v0.s[1] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "bne 1b \n" + "2:\n" + "fadd v8.4s, v8.4s, v9.4s \n" + "st1 {v8.4s}, [%[c]], #16 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v8", "v9", "v16", "v17", + "v18", "v19", "cc", "memory" + ); + b += 4; + } +#else + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vld1.32 {d0-d3}, [%[b]]! \n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q8, q4, d0[0] \n" + "vmul.f32 q9, q4, d2[0] \n" + "vmul.f32 q10, q4, d4[0] \n" + "vmul.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q12, q4, d0[0] \n" + "vmul.f32 q13, q4, d2[0] \n" + "vmul.f32 q14, q4, d4[0] \n" + "vmul.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "beq 2f \n" + "1:\n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! 
\n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q12, q4, d0[0] \n" + "vmla.f32 q13, q4, d2[0] \n" + "vmla.f32 q14, q4, d4[0] \n" + "vmla.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "bne 1b \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q8, q4, d0[0] \n" + "vmul.f32 q9, q4, d2[0] \n" + "vmul.f32 q10, q4, d4[0] \n" + "vmul.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! 
\n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]!\n" + "vst1.32 {d20-d23}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + "vmul.f32 q5, q1, d0[0] \n" + "vmul.f32 q6, q2, d0[1] \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + "beq 2f \n" + "1:\n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + "vmla.f32 q5, q1, d0[0] \n" + "vmla.f32 q6, q2, d0[1] \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + "bne 1b \n" + "2:\n" + "vadd.f32 q5, q5, q6 \n" + "vst1.32 {d10-d11}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "cc", "memory" + ); + // clang-format on + b += 4; + } +#endif + A_packed += lda; + } +} + +void sgemm_prepack_c4(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx) { + if (N > 16) { + sgemm_prepack_c4_common( + M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx); + } else { + sgemm_prepack_c4_small( + M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx); + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/packed_sgemm_c4.h b/lite/backends/arm/math/packed_sgemm_c4.h new file mode 100644 index 0000000000000000000000000000000000000000..3229ff3e0774ce8bff02b12d79d7ec50ed873cea --- /dev/null +++ b/lite/backends/arm/math/packed_sgemm_c4.h @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
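// A hedged usage sketch (illustrative only, not part of the shipped API):
// a typical call into the c4-packed GEMM declared below. A must already be
// packed into 4-row blocks (MBLOCK_C4) and C is produced in the same c4
// layout; `ctx` is an initialized ARMContext.
//
//   paddle::lite::arm::math::sgemm_prepack_c4(
//       M, N, K, A_packed, B, C, bias,
//       /*has_bias=*/true, /*has_relu=*/false, &ctx);
//
// The dispatcher in packed_sgemm_c4.cc routes N > 16 to the cache-blocked
// sgemm_prepack_c4_common path and smaller N to the register-resident
// sgemm_prepack_c4_small kernels shown above.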
+ +#pragma once + +#include +#include "lite/core/context.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +constexpr int MBLOCK_C4 = 4; +constexpr int NBLOCK_C4 = 8; +constexpr int KBLOCK_C4 = 4; + +void sgemm_prepack_c4(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx); +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx); +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + ARMContext* ctx); +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index a857e9830c54b568c93afa4c1aa119ed2baffa1e..07cbd00378c082e311e194c7b22b6d3cb195a63a 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -46,7 +46,7 @@ void pooling_basic(const float* din, int stride_h = strides[0]; int stride_w = strides[1]; int pad_h = paddings[0]; - int pad_w = paddings[1]; + int pad_w = paddings[2]; int size_channel_in = win * hin; int size_channel_out = wout * hout; if (global_pooling) { @@ -125,18 +125,22 @@ void pooling_basic(const float* din, int bh = kernel_h; int bw = kernel_w; if (ew == win) { - bw = sw + kernel_w >= win + pad_w ? win + pad_w - : sw + kernel_w; + bw = (sw + kernel_w) >= (win + paddings[3]) + ? (win + paddings[3]) + : (sw + kernel_w); bw -= sw; - if (sw - pad_w < 0 && sw + kernel_w > win + pad_w) { + if ((sw - pad_w) < 0 && + (sw + kernel_w) > (win + paddings[3])) { bw += pad_w; } } if (eh == hin) { - bh = sh + kernel_h >= hin + pad_h ? hin + pad_h - : sh + kernel_h; + bh = (sh + kernel_h) >= (hin + paddings[1]) + ? 
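/* Hedged note on the padding change in this hunk: `paddings` is now a
   four-element vector ordered {top, bottom, left, right}, so pad_h reads
   paddings[0], pad_w reads paddings[2], and the window clamps use
   paddings[3] (right) above and paddings[1] (bottom) below, truncating
   kernels that run past the padded edge. */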
(hin + paddings[1]) + : (sh + kernel_h); bh -= sh; - if (sh - pad_h < 0 && sh + kernel_h > hin + pad_h) { + if ((sh - pad_h) < 0 && + (sh + kernel_h) > (hin + paddings[1])) { bh += pad_h; } } @@ -163,7 +167,7 @@ void pooling_basic(const float* din, "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n" \ "fmax v6.4s, v4.4s, v5.4s \n" \ "subs %w[cnt], %w[cnt], #1 \n" \ - "fmax %w[vmax].4s, %w[vmax].4s, v6.4s \n" \ + "fmax %[vmax].4s, %[vmax].4s, v6.4s \n" \ "bne 1b \n" #define GLOBAL_AVG \ "1: \n" \ @@ -172,7 +176,7 @@ void pooling_basic(const float* din, "ld1 {v0.4s-v1.4s}, [%[data_in_channel]], #32 \n" \ "fadd %[vsum].4s, %[vsum].4s, v3.4s \n" \ "subs %w[cnt], %w[cnt], #1 \n" \ - "fadd %w[vsum].4s, %w[vsum].4s, v4.4s \n" \ + "fadd %[vsum].4s, %[vsum].4s, v4.4s \n" \ "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n" \ "bne 1b \n" @@ -894,6 +898,121 @@ void pooling_global_avg(const float* din, } } +void pooling1x1s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + int win_ext = w_unroll_size * 8; + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + memset(zero_ptr, 0, win * sizeof(float)); + auto write_ptr = + static_cast(TargetMalloc(TARGET(kARM), wout * sizeof(float))); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + for (int h = 0; h < hout; h += 4) { + const float* din0_ptr = data_in_channel + h * 2 * win; + const float* din1_ptr = din0_ptr + 2 * win; + const float* din2_ptr = din1_ptr + 2 * win; + const float* din3_ptr = din2_ptr + 2 * win; + + float* doutr0 = data_out_channel + h * wout; + float* doutr1 = doutr0 + wout; + float* doutr2 = doutr1 + wout; + float* doutr3 = doutr2 + wout; + if (h + 4 > hout) { + switch (h + 4 - hout) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + if (h * 2 + 7 > hin) { + switch (h * 2 + 7 - hin) { + case 7: + din0_ptr = zero_ptr; + case 6: + case 5: + din1_ptr = zero_ptr; + case 4: + case 3: + din2_ptr = zero_ptr; + case 2: + case 1: + din3_ptr = zero_ptr; + default: + break; + } + } + for (int i = 0; i < w_unroll_size; i++) { + float32x4x2_t din0 = vld2q_f32(din0_ptr); + float32x4x2_t din1 = vld2q_f32(din1_ptr); + float32x4x2_t din2 = vld2q_f32(din2_ptr); + float32x4x2_t din3 = vld2q_f32(din3_ptr); + din0_ptr += 8; + din1_ptr += 8; + din2_ptr += 8; + din3_ptr += 8; + + vst1q_f32(doutr0, din0.val[0]); + vst1q_f32(doutr1, din1.val[0]); + vst1q_f32(doutr2, din2.val[0]); + vst1q_f32(doutr3, din3.val[0]); + + doutr0 += 4; + doutr1 += 4; + doutr2 += 4; + doutr3 += 4; + } + int j = win_ext; + for (int i = 0; i < w_unroll_remian; i++) { + if (j >= win) { + *doutr0++ = 0.f; + *doutr1++ = 0.f; + *doutr2++ = 0.f; + *doutr3++ = 0.f; + } else { + *doutr0++ = *din0_ptr; + *doutr1++ = *din1_ptr; + *doutr2++ = *din2_ptr; + *doutr3++ = *din3_ptr; + din0_ptr += 2; + din1_ptr += 2; + din2_ptr += 2; + din3_ptr += 2; + } + j += 2; + } + } + } + } + 
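/* What the intrinsic loop above computes, as a hedged scalar sketch:
   1x1 max pooling with stride 2 and pad 0 reduces to picking every other
   element of every other row, i.e.

     for (int oh = 0; oh < hout; ++oh)
       for (int ow = 0; ow < wout; ++ow)
         dout[oh * wout + ow] = din[(2 * oh) * win + 2 * ow];

   vld2q_f32 de-interleaves eight floats into even/odd lanes and .val[0]
   (the even columns) is stored directly; rows read past hin fall back to
   zero_ptr, and columns past win are written as 0.f. */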
TargetFree(TARGET(kARM), zero_ptr); + TargetFree(TARGET(kARM), write_ptr); +} + void pooling2x2s2_max(const float* din, float* dout, int num, diff --git a/lite/backends/arm/math/pooling.h b/lite/backends/arm/math/pooling.h index 9288f27bbc7519f1b06bfa1f119a21a33611f74c..701732cb453bfc9f2e970c83c8d713e70a205434 100644 --- a/lite/backends/arm/math/pooling.h +++ b/lite/backends/arm/math/pooling.h @@ -64,6 +64,16 @@ void pooling_global_avg(const float* din, int hin, int win); +void pooling1x1s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + void pooling2x2s2_max(const float* din, float* dout, int num, diff --git a/lite/backends/arm/math/reduce_prod.cc b/lite/backends/arm/math/reduce_prod.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7b3f7095f2087af365d0765f49df7902df42bb9 --- /dev/null +++ b/lite/backends/arm/math/reduce_prod.cc @@ -0,0 +1,23 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/backends/arm/math/reduce_prod.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math {} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/reduce_prod.h b/lite/backends/arm/math/reduce_prod.h new file mode 100644 index 0000000000000000000000000000000000000000..6c8898288fa498a6f97709a27306e6975dffc975 --- /dev/null +++ b/lite/backends/arm/math/reduce_prod.h @@ -0,0 +1,185 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/

+#pragma once
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void reduce_prod_n(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int hw_size = height_in * width_in;
+  int chw_size = channel_in * hw_size;
+  int data_index, src_index, src_index0;
+  for (int c = 0; c < channel_in; ++c) {
+    for (int h = 0; h < height_in; ++h) {
+      for (int w = 0; w < width_in; ++w) {
+        data_index = c * hw_size + h * width_in + w;
+        dst[data_index] = static_cast<T>(1);
+        for (int n = 0; n < num_in; ++n) {
+          src_index = n * chw_size + data_index;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_c(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int hw_size = height_in * width_in;
+  int chw_size = hw_size * channel_in;
+  int data_index, src_index0, src_index;
+  for (int n = 0; n < num_in; ++n) {
+    for (int h = 0; h < height_in; ++h) {
+      for (int w = 0; w < width_in; ++w) {
+        data_index = n * hw_size + h * width_in + w;
+        src_index0 = n * chw_size + h * width_in + w;
+        dst[data_index] = static_cast<T>(1);
+        for (int c = 0; c < channel_in; ++c) {
+          src_index = src_index0 + c * hw_size;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_h(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int cw_size = channel_in * width_in;
+  int chw_size = cw_size * height_in;
+  int hw_size = height_in * width_in;
+  int data_index, src_index, src_index0;
+  for (int n = 0; n < num_in; ++n) {
+    for (int c = 0; c < channel_in; ++c) {
+      for (int w = 0; w < width_in; ++w) {
+        data_index = n * cw_size + c * width_in + w;
+        src_index0 = n * chw_size + c * hw_size + w;
+        dst[data_index] = static_cast<T>(1);
+        for (int h = 0; h < height_in; ++h) {
+          src_index = src_index0 + h * width_in;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_w(const T* src,
+                   T* dst,
+                   int num_in,
+                   int channel_in,
+                   int height_in,
+                   int width_in) {
+  int ch_size = channel_in * height_in;
+  int hw_size = height_in * width_in;
+  int chw_size = ch_size * width_in;
+  int data_index = 0;
+  int src_index0 = 0;
+  int src_index = 0;
+  for (int n = 0; n < num_in; ++n) {
+    for (int c = 0; c < channel_in; ++c) {
+      for (int h = 0; h < height_in; ++h) {
+        data_index = n * ch_size + c * height_in + h;
+        src_index0 = n * chw_size + c * hw_size + h * width_in;
+        dst[data_index] = static_cast<T>(1);
+        for (int w = 0; w < width_in; ++w) {
+          src_index = src_index0 + w;
+          dst[data_index] *= src[src_index];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void reduce_prod_nc(const T* src,
+                    T* dst,
+                    int num_in,
+                    int channel_in,
+                    int height_in,
+                    int width_in) {
+  // reduce n first.
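/* Hedged note: the mixed-axis reductions here and below (_nc, _ch, _hw)
   compose two single-axis passes through a scratch tensor. For an input
   of shape {2, 3, H, W}, reduce_prod_nc first collapses N into a
   {1, 3, H, W} scratch, then collapses C out of that scratch. */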
+ DDimLite ddimA({1, channel_in, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + auto* tmp_out = tensor_tmp.mutable_data(); + reduce_prod_n(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_prod_c(tmp_out, dst, 1, channel_in, height_in, width_in); +} + +template +void reduce_prod_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce c first + DDimLite ddimA({num_in, 1, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + auto* tmp_out = tensor_tmp.mutable_data(); + reduce_prod_c(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_prod_h(tmp_out, dst, num_in, 1, height_in, width_in); +} + +template +void reduce_prod_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce h first + DDimLite ddimA({num_in, channel_in, 1, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + auto* tmp_out = tensor_tmp.mutable_data(); + reduce_prod_h(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_prod_w(tmp_out, dst, num_in, channel_in, 1, width_in); +} + +template +void reduce_prod_all(const T* src, T* dst, int64_t total_num) { + dst[0] = static_cast(1); + for (int n = 0; n < total_num; ++n) { + dst[0] *= src[n]; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/sgemm.cc b/lite/backends/arm/math/sgemm.cc index f3123ddd718ee61b6430d2b7f14480b79435291a..f2ba090222d491f4032aaf4cf3dbb29b4c53708d 100644 --- a/lite/backends/arm/math/sgemm.cc +++ b/lite/backends/arm/math/sgemm.cc @@ -34,7 +34,7 @@ void sgemm(bool is_transA, int ldc, const float* bias, bool is_bias, - bool is_relu, + const operators::ActivationParam act_param, ARMContext* ctx) { int hblock = get_hblock(ctx); int m_roundup = hblock * ((M + hblock - 1) / hblock); @@ -56,7 +56,7 @@ void sgemm(bool is_transA, ldc, bias, is_bias, - is_relu, + act_param, ctx); TargetFree(TargetType::kARM, packed_A); } diff --git a/lite/backends/arm/math/sgemm.h b/lite/backends/arm/math/sgemm.h index 08f68fb3d41e5d0a837f57a8d28acd82dd3f8cb4..b48080855fa8eedad9d619c1fbc84c9fd0040504 100644 --- a/lite/backends/arm/math/sgemm.h +++ b/lite/backends/arm/math/sgemm.h @@ -39,7 +39,7 @@ void sgemm(bool is_transA, int ldc, const float* bias, bool is_bias, - bool is_relu, + const operators::ActivationParam act_param, ARMContext* ctx); } // namespace math diff --git a/lite/backends/arm/math/sgemv.cc b/lite/backends/arm/math/sgemv.cc index 1830423136cc883d30d4eecad0eb9fcfc9ded6ba..98404fe60fdb1384d390458e10dac8c967fd2b21 100644 --- a/lite/backends/arm/math/sgemv.cc +++ b/lite/backends/arm/math/sgemv.cc @@ -22,35 +22,87 @@ namespace lite { namespace arm { namespace math { -void sgemv(const bool transA, - const int M, +void sgemv(const int M, const int N, const float *A, const float *x, - float *y); - -void sgemv_relu(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y); + float *y, + bool flag_bias, + const float *bias); -void sgemv_bias(const bool transA, - const int M, +void sgemv_relu(const int M, const int N, const float *A, const float *x, float *y, + bool flag_bias, const float *bias); -void sgemv_bias_relu(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y, - const float *bias); +void sgemv_relu6(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const 
float *bias, + const float six); + +void sgemv_leakey_relu(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float alpha); + +void sgemv_trans(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha); + +bool sgemv(const float *A, + const float *x, + float *y, + bool transA, + int M, + int N, + bool is_bias, + const float *bias, + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha) { + if (transA) { + sgemv_trans(M, N, A, x, y, is_bias, bias, flag_act, act, ctx, six, alpha); + } else { + if (flag_act) { + if (act == lite_api::ActivationType::kRelu) { + sgemv_relu(M, N, A, x, y, is_bias, bias); + } else if (act == lite_api::ActivationType::kRelu6) { + sgemv_relu6(M, N, A, x, y, is_bias, bias, six); + } else if (act == lite_api::ActivationType::kLeakyRelu) { + sgemv_leakey_relu(M, N, A, x, y, is_bias, bias, alpha); + } else { + LOG(FATAL) + << "sgemv no transA only support relu, relu6, leakey relu fusion"; + } + } else { + sgemv(M, N, A, x, y, is_bias, bias); + } + } + return true; +} + #ifdef __aarch64__ void sgemv_trans(const int M, const int N, @@ -59,8 +111,11 @@ void sgemv_trans(const int M, float *y, bool flag_bias, const float *bias, - bool flag_relu, - const ARMContext *ctx) { + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha) { int m_cnt16 = M >> 4; int m_cnt8 = (M & 15) >> 3; int m_cnt4 = (M & 15 & 7) >> 2; @@ -281,26 +336,70 @@ void sgemv_trans(const int M, valid_ths = rdc_ths; rdc_ths = rdc_ths >> 1; } - if (flag_relu) { + if (flag_act) { float *in_y = y_buf; float32x4_t vzero = vdupq_n_f32(0.f); - if (cnt4 > 0) { - int cnt = cnt4; - asm volatile( - "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ - "1:\n" - "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */ - "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ - "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ - "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ - "bne 1b \n" /* branch to label 1*/ - "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ - : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) - : [vzero] "w"(vzero) - : "v0", "v1", "cc", "memory"); - } - for (int r = 0; r < remain; ++r) { - y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + if (act == lite_api::ActivationType::kRelu) { + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "1:\n" + "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */ + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "v0", "v1", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] > 0.f ? 
in_y[r] : 0.f; + } + } else if (act == lite_api::ActivationType::kRelu6) { + float32x4_t vsix = vdupq_n_f32(six); + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "1:\n" + "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu6 */ + "fmin v1.4s, v1.4s, %[vsix].4s \n" /* v1 relu6 */ + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [vsix] "w"(vsix) + : "v0", "v1", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + y[r] = y[r] > six ? six : y[r]; + } + } else if (act == lite_api::ActivationType::kLeakyRelu) { + float32x4_t valpha = vdupq_n_f32(alpha); + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "1:\n" + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v0.4s, %[valpha].4s \n" /* vmulq_f32 */ + "bif v0.16b, v5.16b, v4.16b \n" /* choose */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v0.4s}, [%[out_y]], #16 \n" /* store v0 to y */ + "bne 1b \n" /* branch to label 1*/ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [valpha] "w"(valpha) + : "v0", "v4", "v5", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] < 0.f ? alpha * in_y[r] : in_y[r]; + } } } else { memcpy(y, y_buf, M * sizeof(float)); @@ -314,8 +413,11 @@ void sgemv_trans(const int M, float *y, bool flag_bias, const float *bias, - bool flag_relu, - const ARMContext *ctx) { + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha) { int m_cnt8 = M >> 3; int m_cnt4 = (M & 7) >> 2; int m_remain = M & 7 & 3; @@ -497,43 +599,73 @@ void sgemv_trans(const int M, valid_ths = rdc_ths; rdc_ths = rdc_ths >> 1; } - if (flag_relu) { + // do activation + if (flag_act) { float *in_y = y_buf; float32x4_t vzero = vdupq_n_f32(0.f); - if (m_cnt8 > 0) { - int cnt8 = m_cnt8; - asm volatile( - "vld1.32 {d0-d3}, [%[in_y]]! \n" /* load y to q0, q1 */ - "1:\n" - "vmax.f32 q2, q0, %q[vzero] \n" /* q0 relu */ - "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ - "vmax.f32 q3, q1, %q[vzero] \n" /* q1 relu */ - "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ - "vst1.32 {d4-d7}, [%[out_y]]! \n" /* store q0, q1 to y*/ - "vld1.32 {d2-d3}, [%[in_y]]! \n" /* load y to q0 */ - "bne 1b \n" /* branch to label 1*/ - "sub %[in_y], %[in_y], #32 \n" /* restore in_y */ - : [cnt] "+r"(cnt8), [in_y] "+r"(in_y), [out_y] "+r"(y) - : [vzero] "w"(vzero) - : "q0", "q1", "q2", "q3", "cc", "memory"); - } - if (m_cnt4 > 0) { - int cnt4 = m_cnt4; - asm volatile( - "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ - "1:\n" - "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */ - "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ - "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ - "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ - "bne 1b \n" /* branch to label 1*/ - "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ - : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) - : [vzero] "w"(vzero) - : "q0", "q1", "cc", "memory"); - } - for (int r = 0; r < m_remain; ++r) { - y[r] = in_y[r] > 0.f ? 
in_y[r] : 0.f; + m_cnt4 = M >> 2; + m_remain = M & 3; + if (act == lite_api::ActivationType::kRelu) { + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "1:\n" + "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */ + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + } + } else if (act == lite_api::ActivationType::kRelu6) { + float32x4_t vsix = vdupq_n_f32(six); + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "1:\n" + "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu6 */ + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "vmin.f32 q1, q1, %q[vsix] \n" /* q0 relu6 */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [vsix] "w"(vsix) + : "q0", "q1", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + y[r] = y[r] > six ? six : y[r]; + } + } else if (act == lite_api::ActivationType::kLeakyRelu) { + float32x4_t valpha = vdupq_n_f32(alpha); + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "1:\n" + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "vcge.f32 q3, q0, %q[vzero] \n" /* vcgeq_f32 */ + "vmul.f32 q4, q0, %q[valpha] \n" /* vmulq_f32 */ + "vbif q0, q4, q3 \n" /* choose */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d0-d1}, [%[out_y]]! \n" /* store q0 to y */ + "bne 1b \n" /* branch to label 1*/ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [valpha] "w"(valpha) + : "q0", "q3", "q4", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] < 0.f ? alpha * in_y[r] : in_y[r]; + } } } else { memcpy(y, y_buf, M * sizeof(float)); @@ -541,41 +673,6 @@ void sgemv_trans(const int M, } #endif // __aarch64__ -bool sgemv(const float *A, - const float *x, - float *y, - bool transA, - int M, - int N, - bool is_bias, - const float *bias, - bool is_relu, - const ARMContext *ctx) { - if (transA) { - sgemv_trans(M, N, A, x, y, is_bias, bias, is_relu, ctx); - } else { - if (is_bias) { - //! with bias - if (is_relu) { - //! with relu - sgemv_bias_relu(transA, M, N, A, x, y, bias); - } else { - //! without relu - sgemv_bias(transA, M, N, A, x, y, bias); - } - } else { - //! without bias - if (is_relu) { - //! with relu - sgemv_relu(transA, M, N, A, x, y); - } else { - //! without relu - sgemv(transA, M, N, A, x, y); - } - } - } - return true; -} // clang-format off //! 
define compute kernel #ifdef __aarch64__ @@ -715,19 +812,19 @@ bool sgemv(const float *A, #define SGEMV_KERNEL_1 \ /* check main loop */ \ "cmp %w[cnt], #1 \n" /* check whether has main loop */ \ - "blt 2f \n" /* jump to tail */ /* main loop */ \ - "1: \n" /* main loop */ \ - "ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \ - "ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \ - "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ - "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ - "fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \ + "blt 2f \n" /* jump to tail */ \ + "1: \n" /* main loop */ \ + "ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \ + "ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \ + "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ + "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ + "fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \ "bne 1b \n" /* jump to main loop */ \ /* pair add to final result */ \ "2: \n" /* reduce to scale */ \ "fadd v9.4s, v0.4s, v1.4s \n" /* add 2 vector */ \ "faddp v10.4s, v9.4s, v9.4s\n" /* pair add to vector */ \ - "faddp s8, v10.2s \n" /* pair add to scale */ /* check tails */ \ + "faddp s8, v10.2s \n" /* pair add to scale */ \ "cmp %w[tail], #1 \n" /* check whether has tail */ \ "blt 4f \n" /* jump to end */ \ "3: \n" /* tail loop */ \ @@ -737,43 +834,100 @@ bool sgemv(const float *A, "subs %w[tail], %w[tail], #1\n" /* sub tail loop count */ \ "bne 3b \n" /* jump to tail loop */ -#define SGEMV_OUT_8 \ - /* end */ \ - "4: \n" /* end */ \ - "stp s8, s9, [%[out]] \n" /* save result */ \ - "stp s10, s11, [%[out], #8] \n" /* save result */ \ - "stp s12, s13, [%[out], #16]\n" /* save result */ \ - "stp s14, s15, [%[out], #24]\n" /* save result */ +#define SGEMV_OUT_8 \ + /* end */ \ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "stp q8, q9, [%[out]] \n" /* save result */ #define SGEMV_OUT_8_RELU \ /* end */ \ - "4: \n" /* end */ \ - "movi d0, #0 \n" /* zero data for relu */ \ - "fmax s8, s8, s0 \n" /* relu */ \ - "fmax s9, s9, s0 \n" /* relu */ \ - "fmax s10, s10, s0 \n" /* relu */ \ - "fmax s11, s11, s0 \n" /* relu */ \ - "fmax s12, s12, s0 \n" /* relu */ \ - "fmax s13, s13, s0 \n" /* relu */ \ - "fmax s14, s14, s0 \n" /* relu */ \ - "fmax s15, s15, s0 \n" /* relu */ \ - "stp s8, s9, [%[out]] \n" /* save result */ \ - "stp s10, s11, [%[out], #8] \n" /* save result */ \ - "stp s12, s13, [%[out], #16]\n" /* save result */ \ - "stp s14, s15, [%[out], #24]\n" /* save result */ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "movi v2.4s, #0 \n" /* zero data for relu */\ + "fmax v8.4s, v8.4s, v2.4s \n" /* relu */ \ + "fmax v9.4s, v9.4s, v2.4s \n" /* relu */ \ + "stp q8, q9, [%[out]] \n" /* save result */ -#define SGEMV_OUT_1 \ - /* end */ \ - "4: \n" /* end */ \ +#define SGEMV_OUT_8_RELU6 \ + /* end */ \ + "4: \n" /* end */ \ + "mov v8.s[1], 
v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "movi v2.4s, #0 \n" /* zero data for relu6 */\ + "fmax v8.4s, v8.4s, v2.4s \n" /* relu6 */ \ + "fmax v9.4s, v9.4s, v2.4s \n" /* relu6 */ \ + "fmin v8.4s, v8.4s, %[vsix].4s \n" /* relu */ \ + "fmin v9.4s, v9.4s, %[vsix].4s \n" /* relu */ \ + "stp q8, q9, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_8_LEAKEY_RELU \ + /* end */ \ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "movi v2.4s, #0 \n" /* zero data for leakey relu */ \ + "fcmge v4.4s, v8.4s, v2.4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v8.4s, %[valpha].4s \n" /* vmulq_f32 */ \ + "fcmge v6.4s, v9.4s, v2.4s \n" /* vcgeq_f32 */ \ + "fmul v7.4s, v9.4s, %[valpha].4s \n" /* vmulq_f32 */ \ + "bif v8.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v9.16b, v7.16b, v6.16b \n" /* choose*/ \ + "stp q8, q9, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_1 \ + /* end */ \ + "4: \n" /* end */ \ "str s8, [%[out]] \n" /* save result */ #define SGEMV_OUT_1_RELU \ /* end */ \ "4: \n" /* end */ \ - "movi d0, #0 \n" /* zero data for relu */ \ - "fmax s8, s8, s0 \n" /* relu */ \ + "movi d1, #0 \n" /* zero data for relu */ \ + "fmax s8, s8, s1 \n" /* relu */ \ + "str s8, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_1_RELU6 \ + /* end */ \ + "4: \n" /* end */ \ + "movi d1, #0 \n" /* zero data for relu6 */ \ + "fmov s2, %w[six] \n" /* mov six to s2 */ \ + "fmax s8, s8, s1 \n" /* relu6 */ \ + "fmin s8, s8, s2 \n" /* relu6 */ \ "str s8, [%[out]] \n" /* save result */ +#define SGEMV_OUT_1_LEAKEY_RELU \ + /* end */ \ + "4: \n" /* end */ \ + "fmov s1, %w[alpha] \n" /* mov alpha to s1 */ \ + "fcmp s8, #0 \n" /* cmp with zero*/ \ + "bge 5f \n" /* if ge zero */ \ + "fmul s8, s8, s1 \n" /* out * alpha */ \ + "5: \n" /* leakey relu label */ \ + "str s8, [%[out]] \n" /* save result */ + #else // __aarch64__ #define SGEMV_IN_4 \ @@ -841,14 +995,13 @@ bool sgemv(const float *A, "vmla.f32 q2, q5, q11 @ mul add\n" \ "vmla.f32 q3, q5, q13 @ mul add\n" \ "bne 1b @ jump to main loop\n" \ - /* pair add to final result */ \ "2: @ pair add \n" \ "vpadd.f32 d8, d0, d1 @ pair add, first step\n" \ "vpadd.f32 d9, d2, d3 @ pair add, first step\n" \ "vpadd.f32 d10, d4, d5 @ pair add, first step\n" \ "vpadd.f32 d11, d6, d7 @ pair add, first step\n" \ "vpadd.f32 d0, d8, d9 @ pair add, second step\n" \ - "vpadd.f32 d1, d10, d11 @ pair add, second step\n" /* check tails */ \ + "vpadd.f32 d1, d10, d11 @ pair add, second step\n" \ "cmp %[tail], #1 @ check whether has tail\n" \ "blt 4f @ jump to end\n" \ "3: @ tail loop\n" \ @@ -876,7 +1029,7 @@ bool sgemv(const float *A, "bne 1b @ jump to main loop\n" \ "2: @ end processing\n" \ "vpadd.f32 d2, d0, d1 @ pair add, first step\n" \ - "vpadd.f32 d0, d2, d2 @ pair add, final step\n"/*check tails*/ \ + "vpadd.f32 d0, d2, d2 @ pair add, final step\n" \ "cmp %[tail], #1 @ check whether has mid cols\n" \ "blt 4f @ jump to end\n" \ "3: @ tail 
loop\n" \ @@ -898,6 +1051,25 @@ bool sgemv(const float *A, "vmax.f32 q0, q0, q1 @ relu\n" \ "vst1.32 {d0-d1}, [%[out]] @ save result\n" +#define SGEMV_OUT_4_RELU6 \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 q1, #0 @ zero for relu6\n" \ + "vdup.f32 q2, %[six] @ six for relu6\n" \ + "vmax.f32 q0, q0, q1 @ relu6\n" \ + "vmin.f32 q0, q0, q2 @ relu6\n" \ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + +#define SGEMV_OUT_4_LEAKEY_RELU \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 q1, #0 @ zero for leakey relu\n" \ + "vdup.f32 q2, %[alpha] @ alpha for leakey relu\n" \ + "vcge.f32 q3, q0, q1 @ vcgeq_f32 \n" \ + "vmul.f32 q4, q0, q2 @ vmulq_f32 \n" \ + "vbif q0, q4, q3 @ choose \n" \ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + #define SGEMV_OUT_1 \ /* end */ \ "4: @ end\n" \ @@ -909,14 +1081,36 @@ bool sgemv(const float *A, "vmov.i32 d1, #0 @ zero for relu\n" \ "vmax.f32 d0, d0, d1 @ relu\n" \ "vst1.32 {d0[0]}, [%[out]] @ save result\n" + +#define SGEMV_OUT_1_RELU6 \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 d1, #0 @ zero for relu6\n" \ + "vdup.f32 d4, %[six] @ six for relu6\n" \ + "vmax.f32 d0, d0, d1 @ relu6\n" \ + "vmin.f32 d0, d0, d4 @ relu6\n" \ + "vst1.32 {d0[0]}, [%[out]] @ save result\n" + +#define SGEMV_OUT_1_LEAKEY_RELU \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 d2, #0 @ zero for leakey relu\n" \ + "vdup.f32 d3, %[alpha] @ alpha for leakey relu\n" \ + "vcge.f32 d6, d0, d2 @ vcgeq_f32 \n" \ + "vmul.f32 d8, d0, d3 @ vmulq_f32 \n" \ + "vbif d0, d8, d6 @ choose \n" \ + "vst1.32 {d0[0]}, [%[out]] @ save result\n" + #endif // clang-format on -void sgemv(const bool transA, - const int M, + +void sgemv(const int M, const int N, const float *A, const float *x, - float *y) { + float *y, + bool flag_bias, + const float *bias) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; @@ -926,7 +1120,6 @@ void sgemv(const bool transA, #ifdef __aarch64__ int out_cnt = M >> 3; - #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { int out_idx = j * 8; @@ -940,9 +1133,22 @@ void sgemv(const bool transA, const float *ptr_w5 = ptr_w4 + N; const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; + const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8 + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -954,35 +1160,12 @@ void sgemv(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! 
deal with remains #pragma omp parallel for @@ -992,24 +1175,17 @@ void sgemv(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float tmp[4]; - float tmp1[4]; - float tmp2[4]; - float tmp3[4]; - float tmp4[4]; - asm volatile( - SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1 - : [in] "+r"(ptr_in), - [w0] "+r"(ptr_w0), - [cnt] "+r"(cnt_loop), - [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), - [tmp] "r"(tmp), - [tmp1] "r"(tmp1), - [tmp2] "r"(tmp2), - [tmp3] "r"(tmp3), - [tmp4] "r"(tmp4) - : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc"); } #else // __aarch64__ int out_cnt = M >> 2; @@ -1022,10 +1198,20 @@ void sgemv(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4 + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1033,23 +1219,16 @@ void sgemv(const bool transA, [w3] "+r"(ptr_w3), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + : [out] "r"(ptr_out), + [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! 
deal with remains #pragma omp parallel for @@ -1059,23 +1238,28 @@ void sgemv(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1 + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) + : [out] "r"(ptr_out), [bias0] "r"(bias0) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } -void sgemv_relu(const bool transA, - const int M, +void sgemv_relu(const int M, const int N, const float *A, const float *x, - float *y) { + float *y, + bool flag_bias, + const float *bias) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; @@ -1098,9 +1282,22 @@ void sgemv_relu(const bool transA, const float *ptr_w5 = ptr_w4 + N; const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; + const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8_RELU + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1112,35 +1309,12 @@ void sgemv_relu(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! 
deal with remains #pragma omp parallel for @@ -1150,13 +1324,17 @@ void sgemv_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } asm volatile( - SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) + : [out] "r"(ptr_out), [bias0] "r"(bias0) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } #else // __aarch64__ @@ -1170,10 +1348,20 @@ void sgemv_relu(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4_RELU + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1181,23 +1369,16 @@ void sgemv_relu(const bool transA, [w3] "+r"(ptr_w3), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + : [out] "r"(ptr_out), + [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1207,31 +1388,36 @@ void sgemv_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) + : [out] "r"(ptr_out), [bias0] "r"(bias0) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } -void sgemv_bias(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y, - const float *bias) { +void sgemv_relu6(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float six) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; int cnt = N >> 3; int tail = N & 7; - + float32x4_t vsix = vdupq_n_f32(six); #ifdef __aarch64__ int out_cnt = M >> 3; #pragma omp parallel for @@ -1248,9 +1434,21 @@ void sgemv_bias(const bool transA, const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8 + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 
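/* SGEMV_OUT_8_RELU6, concatenated next, clamps the eight accumulated dot
   products to [0, six]: fmax against a zeroed register, then fmin against
   %[vsix]. Scalar equivalent, as a hedged sketch:
     y[i] = y[i] < 0.f ? 0.f : (y[i] > six ? six : y[i]);
 */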
SGEMV_OUT_8_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1262,35 +1460,13 @@ void sgemv_bias(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local), + [vsix] "w" (vsix) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1300,14 +1476,17 @@ void sgemv_bias(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } asm volatile( - SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [six] "r"(six) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } #else // __aarch64__ @@ -1321,14 +1500,20 @@ void sgemv_bias(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - float bias0 = bias[out_idx]; - float bias1 = bias[out_idx + 1]; - float bias2 = bias[out_idx + 2]; - float bias3 = bias[out_idx + 3]; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4 + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1340,23 +1525,13 @@ void sgemv_bias(const bool transA, [bias0] "r"(bias0), [bias1] "r"(bias1), [bias2] "r"(bias2), - [bias3] "r"(bias3) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + [bias3] "r"(bias3), + [six] "r" (six) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! 
deal with remains #pragma omp parallel for @@ -1366,30 +1541,35 @@ void sgemv_bias(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; - asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [six] "r"(six) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } -void sgemv_bias_relu(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y, - const float *bias) { +void sgemv_leakey_relu(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float alpha) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; int cnt = N >> 3; int tail = N & 7; + float32x4_t valpha = vdupq_n_f32(alpha); #ifdef __aarch64__ int out_cnt = M >> 3; #pragma omp parallel for @@ -1406,9 +1586,21 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_LEAKEY_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1420,35 +1612,13 @@ void sgemv_bias_relu(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local), + [valpha] "w" (valpha) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! 
deal with remains #pragma omp parallel for @@ -1458,14 +1628,17 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } asm volatile( - SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_LEAKEY_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [alpha] "r"(alpha) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } #else // __aarch64__ @@ -1479,14 +1652,20 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - float bias0 = bias[out_idx]; - float bias1 = bias[out_idx + 1]; - float bias2 = bias[out_idx + 2]; - float bias3 = bias[out_idx + 3]; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_LEAKEY_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1498,23 +1677,13 @@ void sgemv_bias_relu(const bool transA, [bias0] "r"(bias0), [bias1] "r"(bias1), [bias2] "r"(bias2), - [bias3] "r"(bias3) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + [bias3] "r"(bias3), + [alpha] "r" (alpha) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! 
deal with remains #pragma omp parallel for @@ -1524,14 +1693,18 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; - asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU - : [in] "+r"(ptr_in), - [w0] "+r"(ptr_w0), - [cnt] "+r"(cnt_loop), - [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) - : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile( + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_LEAKEY_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [alpha] "r"(alpha) + : "q0", "q1", "q3", "q4", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } diff --git a/lite/backends/arm/math/sgemv.h b/lite/backends/arm/math/sgemv.h index aa17349c99e61f7135090318be829149ecd6bb57..53b2c2ab55a2cee51f8535683c5cf34340fd6dab 100644 --- a/lite/backends/arm/math/sgemv.h +++ b/lite/backends/arm/math/sgemv.h @@ -17,23 +17,26 @@ #include #include "lite/core/context.h" #include "lite/core/device_info.h" +#include "lite/operators/op_params.h" namespace paddle { namespace lite { namespace arm { namespace math { -// TODO(xxx): fixme now only support transA = false -bool sgemv(const float* A, - const float* x, - float* y, +bool sgemv(const float *A, + const float *x, + float *y, bool transA, int M, int N, bool is_bias, - const float* bias, - bool is_relu, - const ARMContext* ctx); + const float *bias, + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six = 6.f, + float alpha = 1.f); } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/slice.cc b/lite/backends/arm/math/slice.cc index 8b9a7690509260ed4c6c0e14750d849f657d2fa8..67ca567fea988acfc9e20e2bfc929e9c3a0bbcb8 100644 --- a/lite/backends/arm/math/slice.cc +++ b/lite/backends/arm/math/slice.cc @@ -86,6 +86,13 @@ template void slice(const int* input, std::vector ends, int* out, Context* ctx); +template void slice(const float* input, + std::vector dims, + std::vector axes, + std::vector starts, + std::vector ends, + float* out, + Context* ctx); } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/split.cc b/lite/backends/arm/math/split.cc index 54ea7e62c2567cf2fe490351572968366fda483e..bff29af93b525dc18e19bded03b0770f7f7a33c8 100644 --- a/lite/backends/arm/math/split.cc +++ b/lite/backends/arm/math/split.cc @@ -70,10 +70,12 @@ void split(const float* din, int in_after = in_strides[axis]; int out_after = out_strides[axis]; + const float* din_ptr = din + input_offset; + for (int i = 0; i < before; ++i) { - split_cpy(din + input_offset + i * in_after, - out_data + i * out_after, - out_after); + std::memcpy(out_data, din_ptr, sizeof(float) * out_after); + din_ptr += in_after; + out_data += out_after; } input_offset += out_strides[axis]; } diff --git a/lite/backends/arm/math/split_merge_lod_tenosr.cc b/lite/backends/arm/math/split_merge_lod_tenosr.cc new file mode 100644 index 0000000000000000000000000000000000000000..35dc4a455b7c51e0aab1a45c48460ccc513b9a08 --- /dev/null +++ b/lite/backends/arm/math/split_merge_lod_tenosr.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/split_merge_lod_tenosr.h" +#include <utility> +#include <vector> + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +using LoDAndOffset = std::pair<LoD, std::pair<size_t, size_t>>; +LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, + size_t start_idx, + size_t end_idx, + size_t start_level) { + LoD sub_lod; + for (size_t level_idx = start_level; level_idx < lod.size(); ++level_idx) { + CHECK(start_idx <= end_idx); + CHECK(end_idx < lod[level_idx].size()); + std::vector<uint64_t> level_lens; + for (size_t i = start_idx; i < end_idx; ++i) { + level_lens.push_back(lod[level_idx][i + 1] - lod[level_idx][i]); + } + sub_lod.emplace_back(level_lens); + start_idx = lod[level_idx][start_idx]; + end_idx = lod[level_idx][end_idx]; + } + return LoDAndOffset{sub_lod, {start_idx, end_idx}}; +} + +void AppendLoD(LoD *lod, const LoD &lod_length) { + CHECK(lod->empty() || lod->size() == lod_length.size()); + if (lod->empty()) { + for (size_t i = 0; i < lod_length.size(); ++i) { + lod->emplace_back(std::vector<uint64_t>({0})); + } + } + for (size_t i = 0; i < lod->size(); ++i) { + auto &level = (*lod)[i]; + for (auto len : lod_length[i]) { + level.push_back(level.back() + len); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/split_merge_lod_tenosr.h b/lite/backends/arm/math/split_merge_lod_tenosr.h new file mode 100644 index 0000000000000000000000000000000000000000..47c484aa4a203ed1819a7e810f71858f4ef0b4dd --- /dev/null +++ b/lite/backends/arm/math/split_merge_lod_tenosr.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
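For readers tracing the LoD bookkeeping above: a minimal standalone sketch of what GetSubLoDAndAbsoluteOffset computes for a one-level LoD. The nested-vector LoD alias is spelled out here only for the example; in the tree it comes from lite/core/tensor.h.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using LoD = std::vector<std::vector<uint64_t>>;  // assumption for this sketch

int main() {
  // One level, three sequences covering element ranges [0,2), [2,5), [5,9).
  LoD lod = {{0, 2, 5, 9}};
  std::size_t start_idx = 1, end_idx = 3;  // take the 2nd and 3rd sequences
  // Per-sequence lengths, exactly as the level_lens loop above computes them.
  std::vector<uint64_t> lens;
  for (std::size_t i = start_idx; i < end_idx; ++i) {
    lens.push_back(lod[0][i + 1] - lod[0][i]);  // {3, 4}
  }
  // After the level walk, start/end become absolute element offsets [2, 9).
  std::cout << lens[0] << "," << lens[1] << " -> [" << lod[0][start_idx] << ","
            << lod[0][end_idx] << ")\n";
}

With this input the function returns the sub-LoD {{3, 4}} plus the offset pair {2, 9}: the lengths of the selected sequences and the absolute element range they occupy. AppendLoD is the inverse direction, turning such length lists back into cumulative offsets.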
+ +#pragma once + +#include <utility> +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset( + const LoD &lod, size_t start_idx, size_t end_idx, size_t start_level); + +void AppendLoD(LoD *lod, const LoD &lod_length); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/type_trans.cc b/lite/backends/arm/math/type_trans.cc index 6ded50e75294ad5145b3b88c4c341d4cce09c812..c50abb741ded487efa03d7d46baf2c6f13a8791d 100644 --- a/lite/backends/arm/math/type_trans.cc +++ b/lite/backends/arm/math/type_trans.cc @@ -46,6 +46,7 @@ void fp32_to_int8(const float* din, float inv_scale = 1.f / scale[j % axis_size]; float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vscale = vdupq_n_f32(inv_scale); + float32x4_t vmax = vdupq_n_f32(-127.f); float32x4_t vpoff = vdupq_n_f32(0.5f); float32x4_t vnoff = vdupq_n_f32(-0.5f); const float* din_c = din + j * inner_size; @@ -63,6 +64,14 @@ void fp32_to_int8(const float* din, "fmul v5.4s, v1.4s, %[scale].4s \n" "fmul v6.4s, v2.4s, %[scale].4s \n" "fmul v7.4s, v3.4s, %[scale].4s \n" + "fcmge v8.4s, v4.4s, %[vmax].4s \n" + "fcmge v9.4s, v5.4s, %[vmax].4s \n" + "fcmge v10.4s, v6.4s, %[vmax].4s \n" + "fcmge v11.4s, v7.4s, %[vmax].4s \n" + "bif v4.16b, %[vmax].16b, v8.16b \n" + "bif v5.16b, %[vmax].16b, v9.16b \n" + "bif v6.16b, %[vmax].16b, v10.16b \n" + "bif v7.16b, %[vmax].16b, v11.16b \n" "ldp q0, q1, [%[in]], #32 \n" "subs %[cnt], %[cnt], #1 \n" "FCVTAS v8.4s, v4.4s \n" @@ -79,7 +88,7 @@ void fp32_to_int8(const float* din, "str q8, [%[out]], #16 \n" "bne 0b \n" : [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop) - : [scale] "w"(vscale) + : [scale] "w"(vscale), [vmax] "w"(vmax) : "v0", "v1", "v2", @@ -104,15 +113,23 @@ void fp32_to_int8(const float* din, "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" "vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n" - "vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n" "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" + "vcgt.f32 q8, q3, %q[vzero] @ get mask > 0, in3\n" "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" "vbif.f32 q6, %q[vnoff], q10 @ get right offset\n" - "vbif.f32 q7, %q[vnoff], q11 @ get right offset\n" + "vbif.f32 q7, %q[vnoff], q8 @ get right offset\n" "vmla.f32 q4, q0, %q[vscale] @ mul scale\n" "vmla.f32 q5, q1, %q[vscale] @ mul scale\n" "vmla.f32 q6, q2, %q[vscale] @ mul scale\n" "vmla.f32 q7, q3, %q[vscale] @ mul scale\n" + "vcge.f32 q8, q4, %q[vmax] @ q4 >= vmax \n" + "vcge.f32 q9, q5, %q[vmax] @ q4 >= vmax \n" + "vcge.f32 q10, q6, %q[vmax] @ q4 >= vmax \n" + "vbif q4, %q[vmax], q8 @ choose \n" + "vcge.f32 q8, q7, %q[vmax] @ q4 >= vmax \n" + "vbif q5, %q[vmax], q9 @ choose \n" + "vbif q6, %q[vmax], q10 @ choose \n" + "vbif q7, %q[vmax], q8 @ choose \n" "vcvt.s32.f32 q0, q4 @ cvt to int32\n" "vcvt.s32.f32 q1, q5 @ cvt to int32\n" "vcvt.s32.f32 q2, q6 @ cvt to int32\n" @@ -133,25 +150,16 @@ void fp32_to_int8(const float* din, : [vscale] "w"(vscale), [vpoff] "w"(vpoff), [vnoff] "w"(vnoff), - [vzero] "w"(vzero) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); + [vzero] "w"(vzero), + [vmax] "w"(vmax) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10"); #endif } const float* din_r = din_c + 16 * cnt; signed char* dout_r = dout_c + 16 * cnt; for (int i = 0; i < remain; ++i) { dout_r[i] = saturate_cast<signed char>(roundf(inv_scale * din_r[i])); +
dout_r[i] = dout_r[i] < -127 ? -127 : dout_r[i]; } } } diff --git a/lite/backends/bm/CMakeLists.txt b/lite/backends/bm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e15b9836b875cec8b5e129ad0f6aceb85ff9d33 --- /dev/null +++ b/lite/backends/bm/CMakeLists.txt @@ -0,0 +1,5 @@ +if (NOT LITE_WITH_BM) + return() +endif() + +lite_cc_library(target_wrapper_bm SRCS target_wrapper.cc DEPS ${bm_runtime_libs}) diff --git a/lite/backends/bm/target_wrapper.cc b/lite/backends/bm/target_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..c75c71452269167064c248418098bcb285d09055 --- /dev/null +++ b/lite/backends/bm/target_wrapper.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lite/backends/bm/target_wrapper.h" +#include <bmlib_runtime.h> +#include <cstdlib> +#include <utility> + +namespace paddle { +namespace lite { + +int TargetWrapperBM::device_id_ = 0; +std::map<int, void*> TargetWrapperBM::bm_hds_; + +size_t TargetWrapperBM::num_devices() { + int count = 0; + bm_dev_getcount(&count); + return count; +} + +void TargetWrapperBM::SetDevice(int id) { + /* + if (id < 0 || (size_t)id >= num_devices()) { + LOG(FATAL) << "Failed with invalid device id " << id; + } + */ + device_id_ = id; + if (bm_hds_.find(id) == bm_hds_.end()) { + bm_handle_t bm_handle; + bm_status_t ret = bm_dev_request(&bm_handle, id); + CHECK_EQ(ret, BM_SUCCESS) << "Failed with error code: " + << static_cast<int>(ret); + bm_hds_.insert(std::pair<int, void*>(id, bm_handle)); + } + return; +} + +void* TargetWrapperBM::GetHandle() { + if (bm_hds_.find(device_id_) == bm_hds_.end()) { + LOG(FATAL) << "device not initialized " << device_id_; + } + return bm_hds_.at(device_id_); +} + +void* TargetWrapperBM::Malloc(size_t size) { + void* ptr{}; + + if (bm_hds_.find(device_id_) == bm_hds_.end()) { + SetDevice(device_id_); + } + + bm_handle_t bm_handle = static_cast<bm_handle_t>(bm_hds_.at(device_id_)); + bm_device_mem_t* p_mem = + reinterpret_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t))); + bm_malloc_device_byte(bm_handle, p_mem, size); + ptr = reinterpret_cast<void*>(p_mem); + return ptr; +} + +void TargetWrapperBM::Free(void* ptr) { + if (ptr != NULL) { + bm_handle_t bm_handle = static_cast<bm_handle_t>(bm_hds_.at(device_id_)); + bm_device_mem_t* mem = static_cast<bm_device_mem_t*>(ptr); + bm_free_device(bm_handle, *mem); + free(ptr); + } + return; +} + +void TargetWrapperBM::MemcpySync(void* dst, + const void* src, + size_t size, + IoDirection dir) { + if (bm_hds_.find(device_id_) == bm_hds_.end()) { + return; + } + + bm_handle_t bm_handle = static_cast<bm_handle_t>(bm_hds_.at(device_id_)); + bm_device_mem_t* pmem{}; + const bm_device_mem_t* pcst_mem{}; + + switch (dir) { + case IoDirection::HtoD: + pmem = static_cast<bm_device_mem_t*>(dst); + bm_memcpy_s2d_partial_offset( + bm_handle, *pmem, const_cast<void*>(src), size, 0); + break; + case IoDirection::DtoH: + pcst_mem = static_cast<const bm_device_mem_t*>(src); + bm_memcpy_d2s_partial_offset( + bm_handle, reinterpret_cast<void*>(dst), *pcst_mem, size, 0); + break; + default: + LOG(FATAL) << "Unsupported
IoDirection " << static_cast(dir); + break; + } + return; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/backends/bm/target_wrapper.h b/lite/backends/bm/target_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..2674ffe161582fbd2fe0dfcabbe8e349d13f847f --- /dev/null +++ b/lite/backends/bm/target_wrapper.h @@ -0,0 +1,73 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { + +using TargetWrapperBM = TargetWrapper; + +template <> +class TargetWrapper { + public: + using stream_t = int; + using event_t = int; + + static size_t num_devices(); + static size_t maximum_stream() { return 0; } + + static void SetDevice(int id); + static void CreateStream(stream_t* stream) {} + static void DestroyStream(const stream_t& stream) {} + + static void CreateEvent(event_t* event) {} + static void DestroyEvent(const event_t& event) {} + + static void RecordEvent(const event_t& event) {} + static void SyncEvent(const event_t& event) {} + + static void StreamSync(const stream_t& stream) {} + + static void* Malloc(size_t size); + static void Free(void* ptr); + + static void* GetHandle(); + + static void MemcpySync(void* dst, + const void* src, + size_t size, + IoDirection dir); + + static void MemcpyAsync(void* dst, + const void* src, + size_t size, + IoDirection dir, + const stream_t& stream) {} + + static void MemsetSync(void* devPtr, int value, size_t count) {} + + static void MemsetAsync(void* devPtr, + int value, + size_t count, + const stream_t& stream) {} + + private: + static int device_id_; + static std::map bm_hds_; +}; +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/CMakeLists.txt b/lite/backends/cuda/CMakeLists.txt index a6c3fcc66a789f159cd3a756ed893627b393e1fe..35f5f0ce2d93db59cbb856d8008e6f3138633e42 100644 --- a/lite/backends/cuda/CMakeLists.txt +++ b/lite/backends/cuda/CMakeLists.txt @@ -1,10 +1,9 @@ if(NOT LITE_WITH_CUDA) return() endif() -set(cuda_static_deps cudnn_static cublas_static curand_static - culibos_static cudart_static) +get_property(cuda_deps GLOBAL PROPERTY CUDA_MODULES) -nv_library(target_wrapper_cuda SRCS target_wrapper.cc DEPS ${cuda_static_deps}) -nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_static_deps}) +nv_library(target_wrapper_cuda SRCS target_wrapper.cc DEPS ${cuda_deps}) +nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_deps}) add_subdirectory(math) diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt index 1829bcf330aba31708ac97c97d093afbda197908..fafd74ae7a43d1a769456edfe408c71593d21201 100644 --- a/lite/backends/cuda/math/CMakeLists.txt +++ b/lite/backends/cuda/math/CMakeLists.txt @@ -2,8 +2,7 @@ if(NOT LITE_WITH_CUDA) return() endif() -set(cuda_static_deps cudnn_static cublas_static curand_static - culibos_static cudart_static) +get_property(cuda_static_deps GLOBAL PROPERTY 
CUDA_STATIC_MODULES) nv_library(cuda_activation SRCS activation.cu DEPS ${cuda_static_deps}) nv_library(cuda_scale SRCS scale.cu DEPS ${cuda_static_deps}) @@ -12,6 +11,7 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps}) nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps}) nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps}) +nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps}) nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps}) nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps}) @@ -23,6 +23,7 @@ set ( cuda_type_trans cuda_transpose cuda_elementwise + cudnn_pool cuda_gemm cuda_batched_gemm ) diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc index 72ed3951f6b9b22a5ae1ee6caef8c69708102885..5dd53084f4079ae68c6fda0530fb5de8cf1d3717 100644 --- a/lite/backends/cuda/math/cudnn_conv.cc +++ b/lite/backends/cuda/math/cudnn_conv.cc @@ -31,6 +31,9 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param, auto o_dims = param.output->dims(); int batch = x_dims[0]; + auto paddings = *param.paddings; + auto dilations = *param.dilations; + int iw = x_dims[3]; // nchw int ih = x_dims[2]; int ic = x_dims[1]; @@ -41,10 +44,10 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param, int kh = w_dims[2]; int sw = param.strides[1]; int sh = param.strides[0]; - int pw = param.paddings[1]; - int ph = param.paddings[0]; - int dw = param.dilations[1]; - int dh = param.dilations[0]; + int pw = paddings[2]; + int ph = paddings[0]; + int dw = dilations[1]; + int dh = dilations[0]; CHECK(ic % param.groups == 0) << "The conv input channel should be divisible by the group number."; @@ -86,9 +89,15 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param, this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0)); } +#if CUDNN_VERSION_MIN(7, 0, 0) + cudnnMathType_t math_type = + use_tensor_core_ ?
CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; + CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc_, math_type)); +#endif + if (ic == param.groups && ic == oc && ic != 1) { this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; - } else if (1) { + } else if (!param.var_length) { const auto* i_data = param.x->data<float>(); const auto* w_data = param.filter->data<float>(); auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA)); @@ -133,8 +142,8 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param, this->fwd_algo_ = algo_cache.GetAlgorithm(x_dims.Vectorize(), w_dims.Vectorize(), param.strides, - param.paddings, - param.dilations, + *param.paddings, + *param.dilations, 0, search_func); @@ -311,12 +320,15 @@ bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param, int kw = w_dims[2]; int kh = w_dims[1]; + auto paddings = *param.paddings; + auto dilations = *param.dilations; + int sw = param.strides[1]; int sh = param.strides[0]; - int pw = param.paddings[1]; - int ph = param.paddings[0]; - int dw = param.dilations[1]; - int dh = param.dilations[0]; + int pw = paddings[2]; + int ph = paddings[0]; + int dw = dilations[1]; + int dh = dilations[0]; std::vector<float> weight_scale = param.weight_scale; float input_scale = param.input_scale; diff --git a/lite/backends/cuda/math/cudnn_pool.cc b/lite/backends/cuda/math/cudnn_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..f970fc326b29c4c226e7dc9643e416a3cf24f0eb --- /dev/null +++ b/lite/backends/cuda/math/cudnn_pool.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
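A note on the index changes above (pw moving from paddings[1] to paddings[2]): ConvParam now hands out paddings and dilations behind shared pointers, and the padding vector carries four entries instead of two. A hedged sketch of the assumed layout, inferred from ph staying at index 0 while pw moves to index 2:

#include <vector>

int main() {
  // Assumed four-entry layout: {pad_top, pad_bottom, pad_left, pad_right}.
  std::vector<int> paddings = {1, 1, 2, 2};
  int ph = paddings[0];  // vertical padding, as cudnn_conv.cc now reads it
  int pw = paddings[2];  // horizontal padding; index 1 would be pad_bottom
  (void)ph;
  (void)pw;
}

Under that layout the old paddings[1] would silently read the bottom pad rather than the left one, which is why both the float and int8 conv paths switch pw to index 2. The pooling code below normalizes its paddings to the same four-entry form.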
+ +#include "lite/backends/cuda/math/cudnn_pool.h" +#include "lite/backends/cuda/math/activation.h" +#include "lite/backends/cuda/math/scale.h" +#include "lite/backends/cuda/math/type_trans.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +inline void UpdatePadding(std::vector* paddings, + const bool global_pooling, + const bool adaptive, + const std::vector& data_dims, + const std::vector& strides, + const std::vector& ksize) { + if (paddings->size() == data_dims.size()) { + for (size_t i = 0; i < data_dims.size(); ++i) { + int copy_pad = *(paddings->begin() + 2 * i); + paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); + } + } else { + CHECK(data_dims.size() * 2 == paddings->size()) + << "Paddings size should be the same or twice as the pooling size."; + } + if (global_pooling || adaptive) { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = 0; + } + } +} + +inline void UpdateKsize(std::vector* ksize, + const std::vector& data_dims) { + ksize->resize(static_cast(data_dims.size())); + for (size_t i = 0; i < ksize->size(); ++i) { + *(ksize->begin() + i) = static_cast(data_dims[i]); + } +} + +template <> +bool CudnnPool2DNHWC::create( + const operators::PoolParam& param, Context* ctx) { + return true; +} + +template <> +bool CudnnPool2DNHWC::init(const operators::PoolParam& param, + Context* ctx) { + this->stream_ = ctx->exec_stream(); + CUDNN_CHECK(cudnnCreate(&this->handle_)); + CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_)); + + cudnnCreateTensorDescriptor(&this->input_desc_); + cudnnCreateTensorDescriptor(&this->output_desc_); + cudnnCreatePoolingDescriptor(&this->pooling_desc_); + + return create(param, ctx); +} + +template <> +bool CudnnPool2DNHWC::run( + const operators::PoolParam& param) { + auto x_dims = param.x->dims(); + auto o_dims = param.output->dims(); + int batch = x_dims[0]; + const float* in_data = param.x->data(); + float* out_data = param.output->mutable_data(TARGET(kCUDA)); + + int ih = x_dims[1]; + int iw = x_dims[2]; // nchw + int ic = x_dims[3]; + + int oh = o_dims[1]; + int ow = o_dims[2]; + int oc = o_dims[3]; + + std::vector ksize = param.ksize; + std::vector strides = param.strides; + std::vector paddings = *(param.paddings.get()); + + std::string pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + + std::vector data_dims = {ih, iw}; + UpdatePadding(&paddings, global_pooling, adaptive, data_dims, strides, ksize); + + if (data_dims.size() * 2 == paddings.size()) { + for (size_t i = 0; i < data_dims.size(); ++i) { + paddings.erase(paddings.begin() + i + 1); + } + } + + if (global_pooling) { + UpdateKsize(&ksize, data_dims); + } + CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_, + CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, + batch, + ic, + ih, + iw)); + + CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_, + CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, + batch, + oc, + oh, + ow)); + cudnnPoolingMode_t mode; + if (pooling_type == "max") { + mode = CUDNN_POOLING_MAX; + } else { + mode = exclusive ? 
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + } + CUDNN_CHECK(cudnnSetPoolingNdDescriptor(this->pooling_desc_, + mode, + CUDNN_NOT_PROPAGATE_NAN, + ksize.size(), + ksize.data(), + paddings.data(), + strides.data())); + float alpha = 1.0f; + float beta = 0.0f; + CUDNN_CHECK(cudnnPoolingForward(this->handle_, + this->pooling_desc_, + &alpha, + this->input_desc_, + in_data, + &beta, + this->output_desc_, + out_data)); + + return true; +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_pool.h b/lite/backends/cuda/math/cudnn_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..acdc695b500ab41d615cb98c9501efd729c2fe6a --- /dev/null +++ b/lite/backends/cuda/math/cudnn_pool.h @@ -0,0 +1,79 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include <cudnn.h> +#include <string> +#include <vector> +#include "lite/api/paddle_place.h" +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template <lite_api::PrecisionType Ptype_out> +class CudnnPool2DBase { + public: + CudnnPool2DBase() + : handle_(NULL), + input_desc_(NULL), + output_desc_(NULL), + pooling_desc_(NULL) {} + + ~CudnnPool2DBase() { + if (handle_ != NULL) { + CUDNN_CHECK(cudnnDestroy(handle_)); + } + if (input_desc_) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_)); + } + if (output_desc_) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_)); + } + if (pooling_desc_) { + cudnnDestroyPoolingDescriptor(pooling_desc_); + } + } + + protected: + cudaStream_t stream_; + cudnnHandle_t handle_; + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnPoolingDescriptor_t pooling_desc_; +}; + +template <lite_api::PrecisionType Ptype_out> +class CudnnPool2DNHWC : public CudnnPool2DBase<Ptype_out> { + public: + CudnnPool2DNHWC() : CudnnPool2DBase<Ptype_out>() {} + virtual ~CudnnPool2DNHWC() = default; + virtual bool init(const operators::PoolParam& param, + Context<TARGET(kCUDA)>* ctx); + + virtual bool create(const operators::PoolParam& param, + Context<TARGET(kCUDA)>* ctx); + + virtual bool run(const operators::PoolParam& param); +}; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/elementwise.cu b/lite/backends/cuda/math/elementwise.cu index 57c9ec022a6e49551fd2d56a9b2036de13bf5a2c..8f0ebd1f97a03f03b568de694b986e9540f07c55 100644 --- a/lite/backends/cuda/math/elementwise.cu +++ b/lite/backends/cuda/math/elementwise.cu @@ -13,13 +13,55 @@ // limitations under the License.
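Before reading the new elementwise kernels below: they index the broadcast operand with idx = tid / post % n, i.e. the flat output is viewed as a (pre, n, post) box in which y varies only along the middle axis. A small host-side sketch of that decomposition, with hypothetical sizes:

#include <cstdio>

int main() {
  // Output viewed as a (pre, n, post) box; y is broadcast over pre and post.
  const int pre = 2, n = 3, post = 4;
  for (int tid = 0; tid < pre * n * post; ++tid) {
    int idx = tid / post % n;  // same mapping the CUDA kernels use per thread
    std::printf("out[%d] pairs with y[%d]\n", tid, idx);
  }
}

For example, tid = 17 gives 17 / 4 = 4 and 4 % 3 = 1, so output element 17 is combined with y[1]; pre itself never appears in the formula because y repeats identically for every outer slice.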
#include "lite/backends/cuda/math/elementwise.h" -#include "lite/backends/cuda/math/utils.h" namespace paddle { namespace lite { namespace cuda { namespace math { +template +__global__ void elementwise_kernel(const size_t total, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + int idx = tid / post % n; +#if __CUDA_ARCH__ >= 350 + out_data[tid] = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type); +#else + out_data[tid] = binary_calc(x_data[tid], y_data[idx], type); +#endif + } +} + +template +__global__ void elementwise_relu_kernel(const size_t total, + const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total) { + int idx = tid / post % n; + Dtype temp; +#if __CUDA_ARCH__ >= 350 + temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type); + +#else + temp = binary_calc(x_data[tid], y_data[idx], type); +#endif + out_data[tid] = temp > 0 ? temp : 0; + } +} + template __global__ void elementwise_add_kernel(const size_t total, const Dtype* x_data, @@ -76,6 +118,56 @@ __global__ void elementwise_add_nhwc4_int8_kernel(const size_t total, } } +template +void elementwise(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream) { + int num = pre * n * post; + int thread = 256; + int block = (num + thread - 1) / thread; + elementwise_kernel<<>>( + num, x_data, y_data, out_data, pre, n, post, type); +} + +template +void elementwise_relu(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream) { + int num = pre * n * post; + int thread = 256; + int block = (num + thread - 1) / thread; + elementwise_relu_kernel<<>>( + num, x_data, y_data, out_data, pre, n, post, type); +} + +template void elementwise(const float*, + const float*, + float*, + int, + int, + int, + BinaryOperation, + cudaStream_t); + +template void elementwise_relu(const float*, + const float*, + float*, + int, + int, + int, + BinaryOperation, + cudaStream_t); + template void elementwise_add(int num, const Dtype* x_data, diff --git a/lite/backends/cuda/math/elementwise.h b/lite/backends/cuda/math/elementwise.h index 7fcdf95021ff21379bf94298ed06328dd6d2db09..ce45d0544e5a55a9cdc34bdfacc2b48157f5a198 100644 --- a/lite/backends/cuda/math/elementwise.h +++ b/lite/backends/cuda/math/elementwise.h @@ -15,12 +15,33 @@ #pragma once #include #include +#include "lite/backends/cuda/math/utils.h" namespace paddle { namespace lite { namespace cuda { namespace math { +template +void elementwise(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream); + +template +void elementwise_relu(const Dtype* x_data, + const Dtype* y_data, + Dtype* out_data, + int pre, + int n, + int post, + BinaryOperation type, + cudaStream_t stream); + template void elementwise_add(int num, const Dtype* x_data, diff --git a/lite/backends/cuda/math/gemm.h b/lite/backends/cuda/math/gemm.h index 12194d54b08a533a3812e10b5d2f78134c19da24..85576e65018a0e1bdec6f2bd2fdc590bd35e9656 100644 --- a/lite/backends/cuda/math/gemm.h +++ b/lite/backends/cuda/math/gemm.h @@ -55,6 +55,8 @@ class Gemm { PtypeOut* c, Context* ctx); + cublasHandle_t 
get_handle() const { return cu_handle_; } + private: cudaStream_t exe_stream_; cublasHandle_t cu_handle_; diff --git a/lite/backends/cuda/math/transpose.cu b/lite/backends/cuda/math/transpose.cu index cebcece812dc584d0921edea2fef8f129e430b56..c50840fe269657965db8c58b171fce6819009775 100644 --- a/lite/backends/cuda/math/transpose.cu +++ b/lite/backends/cuda/math/transpose.cu @@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N, const int H, const int W, const T* input, T* out, - CUDAContext* ctx) { + cudaStream_t* stream) { const int dh = (H + kTileDim - 1) / kTileDim; const int dw = (W + kTileDim - 1) / kTileDim; BatchTranspose2DCUDAKernel< - T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>( + T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, *stream>>>( N, H, W, dh, dw, input, out); cudaError_t error = cudaGetLastError(); if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); } -#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T) \ - template <> \ - void NCHW2NHWC<T>(const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - T* Y, \ - CUDAContext* ctx) { \ - BatchTranspose2DCUDAImpl(N, C, HxW, X, Y, ctx); \ - } -TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float) -TYPE_SPECIALIZED_CUDA_NCHW2NHWC(int8_t) -#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC - -#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \ - template <> \ - void NHWC2NCHW<T>(const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - T* Y, \ - CUDAContext* ctx) { \ - BatchTranspose2DCUDAImpl(N, HxW, C, X, Y, ctx); \ - } -TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float) -TYPE_SPECIALIZED_CUDA_NHWC2NCHW(int8_t) -#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW - template <typename T> __global__ void TransposeCUDAKernel(const int size, const int ndim, @@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims, const std::vector<int>& axes, const T* X, T* Y, - CUDAContext* ctx) { + lite::Tensor* Y_dims_, + lite::Tensor* strides_, + cudaStream_t* stream) { CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal"; int ndim = X_dims.size(); std::vector<int> strides(ndim, 0); @@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims, size *= X_dims[i]; } - lite::Tensor Y_dims_, strides_; - Y_dims_.Resize(std::vector<int64_t>({ndim})); - int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA)); - CopySync<TARGET(kCUDA)>( - d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD); + Y_dims_->Resize(std::vector<int64_t>({ndim})); + int* d_y_dims = Y_dims_->mutable_data<int>(TARGET(kCUDA)); + TargetWrapperCuda::MemcpyAsync(d_y_dims, + Y_dims.data(), + sizeof(int) * Y_dims.size(), + IoDirection::HtoD, + *stream); - strides_.Resize(std::vector<int64_t>({ndim})); - int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA)); - CopySync<TARGET(kCUDA)>(d_strides, - strides.data(), - sizeof(int) * strides.size(), - IoDirection::HtoD); + strides_->Resize(std::vector<int64_t>({ndim})); + int* d_strides = strides_->mutable_data<int>(TARGET(kCUDA)); + TargetWrapperCuda::MemcpyAsync(d_strides, + strides.data(), + sizeof(int) * strides.size(), + IoDirection::HtoD, + *stream); const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; - TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>( + TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, *stream>>>( size, ndim, d_strides, d_y_dims, X, Y); auto e = cudaGetLastError(); CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); } -#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \ - template <> \ - void Transpose<T>(const std::vector<int64_t>& X_dims, \ - const std::vector<int>& axes, \ - const T* X, \ - T* Y, \ - CUDAContext* ctx) { \ - TransposeCUDAImpl(X_dims, axes, X, Y, ctx); \ - } -TYPE_SPECIALIZED_CUDA_TRANSPOSE(float) -#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF +template <typename T> +void Transpose<T>::NCHW2NHWC( + int N, int C, int HxW,
const T* X, T* Y, cudaStream_t* stream) { + BatchTranspose2DCUDAImpl(N, C, HxW, X, Y, stream); +} + +template <typename T> +void Transpose<T>::NHWC2NCHW( + int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream) { + BatchTranspose2DCUDAImpl(N, HxW, C, X, Y, stream); +} + +template <typename T> +void Transpose<T>::transpose(T* dst, + const T* src, + const std::vector<int64_t>& src_dims, + const std::vector<int>& axes, + cudaStream_t* stream) { + TransposeCUDAImpl(src_dims, axes, src, dst, &Y_dims_, &strides_, stream); +} + +// template <typename T> +// void Transpose<T>::transpose(T* dst, +// const T* src, +// const std::vector<int>& src_dims, +// const std::vector<int>& axes, +// cudaStream_t* stream) { +// std::vector<int64_t> _src_dims(src_dims.size(), 0); +// std::transform( +// src_dims.begin(), +// src_dims.end(), +// _src_dims.begin(), +// [](int data) -> int64_t { return static_cast<int64_t>(data); }); +// TransposeCUDAImpl(_src_dims, axes, src, dst, &Y_dims_, &strides_, +// stream); +//} + +template class Transpose<float>; +template class Transpose<int8_t>; } // namespace math } // namespace cuda diff --git a/lite/backends/cuda/math/transpose.h b/lite/backends/cuda/math/transpose.h index ba2464547b587f44cd9b0ce287a0d40d37d46411..ed52ba3b5590ab631c3c57a0472e16cb0ed51a91 100644 --- a/lite/backends/cuda/math/transpose.h +++ b/lite/backends/cuda/math/transpose.h @@ -26,17 +26,27 @@ namespace cuda { namespace math { template <typename T> -void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context); +class Transpose { + public: + void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream); -template <typename T> -void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context); + void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, cudaStream_t* stream); -template <typename T> -void Transpose(const std::vector<int64_t>& X_dims, - const std::vector<int>& axes, - const T* X, - T* Y, - CUDAContext* ctx); + void transpose(T* dst, + const T* src, + const std::vector<int64_t>& src_dims, + const std::vector<int>& axes, + cudaStream_t* stream); + + // void transpose(T* dst, + // const T* src, + // const std::vector<int>& src_dims, + // const std::vector<int>& axes, + // cudaStream_t* stream); + + private: + lite::Tensor Y_dims_, strides_; // for transpose.
+}; } // namespace math } // namespace cuda diff --git a/lite/backends/cuda/math/utils.h b/lite/backends/cuda/math/utils.h index b4cd82fd8df6df063d92df709311f3c90e7cf4b6..b6aa9c7d160ad6c8b60b132e4a2bbd7ae1e0b9ff 100644 --- a/lite/backends/cuda/math/utils.h +++ b/lite/backends/cuda/math/utils.h @@ -25,6 +25,24 @@ namespace lite { namespace cuda { namespace math { +enum class BinaryOperation { + kADD = 0, + kMUL = 1, + kDIV = 2, +}; + +template <typename T> +__device__ T binary_calc(T x, T y, BinaryOperation type); + +template <> +__device__ __forceinline__ float binary_calc(float x, + float y, + BinaryOperation type) { + if (type == BinaryOperation::kADD) return x + y; + if (type == BinaryOperation::kMUL) return x * y; + if (type == BinaryOperation::kDIV) return x / y; +} + template <typename T> __device__ T from_float(float x); diff --git a/lite/backends/fpga/CMakeLists.txt b/lite/backends/fpga/CMakeLists.txt index b12fd85caf7e0c79de830b45569e02ba916c34e6..a5207c01a4d5e7b8d05490bd7c9be0dcc01f365e 100644 --- a/lite/backends/fpga/CMakeLists.txt +++ b/lite/backends/fpga/CMakeLists.txt @@ -3,13 +3,35 @@ endif() set(LITE_FPGA_KD_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga/KD") +set(LITE_FPGA_KD_LLAPI_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga/KD/llapi") +set(LITE_FPGA_KD_PE_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga/KD/pes") set(LITE_FPGA_PATH "${PADDLE_SOURCE_DIR}/lite/backends/fpga") message("fpga_kd_path ${LITE_FPGA_KD_PATH}") message("fpga_path ${LITE_FPGA_PATH}") -file(GLOB_RECURSE KD_CPP *.cpp *.cc) +file(GLOB KD_CPP "${LITE_FPGA_KD_PATH}/*.cpp") +file(GLOB PE_CPP "${LITE_FPGA_KD_PE_PATH}/*.cpp") +file(GLOB LLAPI_CPP "${LITE_FPGA_KD_LLAPI_PATH}/*.cpp") file(GLOB FPGA_CPP "${LITE_FPGA_PATH}/*.cc") - -cc_library(kernel_fpga SRCS ${KD_CPP} ${FPGA_CPP}) +set(FPGA_ALL_CPP "") +FOREACH(FILE_PATH ${KD_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list(APPEND FPGA_ALL_CPP KD/${FILE_NAME}) +ENDFOREACH(FILE_PATH) +FOREACH(FILE_PATH ${PE_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list(APPEND FPGA_ALL_CPP KD/pes/${FILE_NAME}) +ENDFOREACH(FILE_PATH) +FOREACH(FILE_PATH ${LLAPI_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list(APPEND FPGA_ALL_CPP KD/llapi/${FILE_NAME}) +ENDFOREACH(FILE_PATH) +FOREACH(FILE_PATH ${FPGA_CPP}) + STRING(REGEX REPLACE ".+/(.+\\..*)" "\\1" FILE_NAME ${FILE_PATH}) + list( APPEND FPGA_ALL_CPP ${FILE_NAME}) +ENDFOREACH(FILE_PATH) +message("fpga kd: ${FPGA_ALL_CPP}") +cc_library(kernel_fpga SRCS ${FPGA_ALL_CPP}) +#cc_library(kernel_fpga SRCS ${KD_CPP} ${FPGA_CPP}) cc_library(lite_tensor_fpga SRCS lite_tensor.cc DEPS memory) -cc_library(fpga_target_wrapper SRCS ${LITE_FPGA_PATH}/target_wrapper.cc DEPS kernel_fpga) +cc_library(fpga_target_wrapper SRCS target_wrapper.cc DEPS kernel_fpga) diff --git a/lite/backends/fpga/KD/debugger.hpp b/lite/backends/fpga/KD/debugger.hpp new file mode 100755 index 0000000000000000000000000000000000000000..9b1189c407d6d601bb3e5ba8172b1455f04710fd --- /dev/null +++ b/lite/backends/fpga/KD/debugger.hpp @@ -0,0 +1,152 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <string> +#include <unordered_map> + +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { + +#define FPGA_PRINT_TENSOR + +class Debugger { + public: + static Debugger& get_instance() { + static Debugger s_instance; + return s_instance; + } + + void registerOutput(std::string op_type, zynqmp::Tensor* tensor) { + if (op_config[op_type]) { + tensor->saveToFile(op_type, true); + } + } + + private: + std::unordered_map<std::string, bool> op_config; + Debugger() { + op_config["concat"] = true; + op_config["pooling"] = true; + op_config["conv"] = true; + op_config["dwconv"] = true; + op_config["ew_add"] = true; + op_config["crop"] = true; + op_config["feed"] = true; + op_config["mul"] = true; + op_config["fetch"] = true; + op_config["boxes"] = true; + op_config["scores"] = true; + op_config["nms"] = true; + op_config["pb_boxes"] = true; + op_config["pb_variances"] = true; + // op_config["fc"] = true; + op_config["softmax"] = true; + } +}; + +inline void chw_to_hwc(Tensor* t, float* dst) { + int num = t->dims()[0]; + int channel = t->dims()[1]; + + int height = 1; + int width = 1; + if (t->dims().size() > 2) { + height = t->dims()[2]; + } + if (t->dims().size() > 3) { + width = t->dims()[3]; + } + const float* chw_data = t->data<float>(); + float* hwc_data = dst; + + int chw = channel * height * width; + int wc = width * channel; + int index = 0; + for (int n = 0; n < num; n++) { + for (int c = 0; c < channel; c++) { + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + hwc_data[n * chw + h * wc + w * channel + c] = chw_data[index]; + index++; + } + } + } + } +} + +inline void read_from_file(lite::Tensor* t, const std::string& path) { + std::ifstream file_stream; + file_stream.open(path); + if (!file_stream) { + return; + } + float* data = t->mutable_data<float>(); + int num = t->numel(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + data[i] = value; + } +} + +inline void save_float(float* data, const std::string& name, int len) { + static int counter = 0; + std::string old_string = std::to_string(counter); + std::string new_string = + std::string(3 - old_string.length(), '0') + old_string; + + std::string file = "arm_" + new_string + name; + counter++; + + std::ofstream ofs; + ofs.open(file); + for (int i = 0; i < len; i++) { + float value = data[i]; + ofs << value << std::endl; + } + ofs.close(); +} + +inline void save_tensor(lite::Tensor* t, + const std::string& name, + bool convert = true) { + float* data = const_cast<float*>(t->data<float>()); + float* dst = new float[t->numel()]; + if (convert) { + chw_to_hwc(t, dst); + data = dst; + } + + save_float(data, name, t->numel()); + delete[] dst; +} + +inline void save_tensor(const lite::Tensor* t, + const std::string& name, + bool convert = true) { + float* data = const_cast<float*>(t->data<float>()); + float* dst = new float[t->numel()]; + if (convert) { + chw_to_hwc(const_cast<lite::Tensor*>(t), dst); + data = dst; + } + save_float(data, name, t->numel()); + delete[] dst; +} } // namespace lite +} // namespace paddle diff --git a/lite/backends/fpga/KD/dl_engine.cpp b/lite/backends/fpga/KD/dl_engine.cpp old mode 100644 new mode
100755 index 9849e4275b5d0f59346b9684530610853f1a560c..ea503518a0f39671e77157f14788a1cadb4579f3 --- a/lite/backends/fpga/KD/dl_engine.cpp +++ b/lite/backends/fpga/KD/dl_engine.cpp @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "lite/backends/fpga/KD/dl_engine.hpp" + namespace paddle { namespace zynqmp { DLEngine::DLEngine() { open_device(); - struct DeviceInfo info; - int ret = get_device_info(info); - filter::set_filter_capacity(info.filter_cap); + int ret = get_device_info(info_); + filter::set_filter_capacity(info_.filter_cap); + filter::set_colunm(info_.colunm); } } // namespace zynqmp diff --git a/lite/backends/fpga/KD/dl_engine.hpp b/lite/backends/fpga/KD/dl_engine.hpp old mode 100644 new mode 100755 index 829f41dfebfabfe5642bd4cf107fc6c54f3ffd86..eddf5ca454cdc9e91f87d6e4f2c8dfc13f35fdc6 --- a/lite/backends/fpga/KD/dl_engine.hpp +++ b/lite/backends/fpga/KD/dl_engine.hpp @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include - #include "lite/backends/fpga/KD/llapi/filter.h" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" @@ -29,8 +28,15 @@ class DLEngine { return s_instance; } + DeviceInfo& deviceInfo(); + + bool isZU3() { return info_.device_type / 100 == 3; } + + float* out_data = nullptr; + private: DLEngine(); + DeviceInfo info_; }; } // namespace zynqmp } // namespace paddle diff --git a/lite/backends/fpga/KD/layout.hpp b/lite/backends/fpga/KD/layout.hpp index 74819cd2120630def0114422b04efe076e1d6cb2..c6b5c911872b6b22633a4319ea708ed23c7e7e36 100644 --- a/lite/backends/fpga/KD/layout.hpp +++ b/lite/backends/fpga/KD/layout.hpp @@ -22,6 +22,7 @@ namespace paddle { namespace zynqmp { enum LayoutType { + None, N, NC, NCHW, @@ -39,6 +40,15 @@ class Layout { virtual int elementCount(const std::vector<int>& dims) = 0; }; +struct None : Layout { + int numIndex() { return -1; } + int channelIndex() { return -1; } + int heightIndex() { return -1; } + int widthIndex() { return -1; } + int alignedElementCount(const std::vector<int>& dims) { return 16; } + virtual int elementCount(const std::vector<int>& dims) { return 1; } +}; + struct NCHW : Layout { int numIndex() { return 0; } int channelIndex() { return 1; } diff --git a/lite/backends/fpga/KD/llapi/bias_scale.cpp b/lite/backends/fpga/KD/llapi/bias_scale.cpp index cd60f27f9896e857f8ad566d285a9b9aea1d4721..339a442207e811be31161ff25f60a080572efe8d 100644 --- a/lite/backends/fpga/KD/llapi/bias_scale.cpp +++ b/lite/backends/fpga/KD/llapi/bias_scale.cpp @@ -14,6 +14,7 @@ limitations under the License.
*/ #include +#include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/llapi/bias_scale.h" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" @@ -54,7 +55,7 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) { *data_in = ptr_aligned; } -void interleave(float **data_in, int num_after_alignment) { +size_t interleave(float **data_in, int num_after_alignment) { float *ptr_uninterleaved = *data_in; float *ptr_interleaved = (float *)fpga_malloc(2 * num_after_alignment * sizeof(float)); // NOLINT @@ -69,6 +70,7 @@ void interleave(float **data_in, int num_after_alignment) { fpga_free(ptr_uninterleaved); *data_in = ptr_interleaved; + return 2 * num_after_alignment * sizeof(float); } void format_bias_scale_array(float **bias_scale_array, @@ -78,8 +80,9 @@ void format_bias_scale_array(float **bias_scale_array, int div_num = (num + element_num_per_division - 1) / element_num_per_division; int element_num_after_division = align_to_x(element_num_per_division, BS_NUM_ALIGNMENT); - interleave(bias_scale_array, div_num * element_num_after_division); - fpga_flush(*bias_scale_array, 2 * element_num_after_division * sizeof(float)); + size_t mem = + interleave(bias_scale_array, div_num * element_num_after_division); + fpga_flush(*bias_scale_array, mem); } void format_bias_array(float **bias_array, int num) { float *ptr_unaligned = *bias_array; diff --git a/lite/backends/fpga/KD/llapi/bias_scale.h b/lite/backends/fpga/KD/llapi/bias_scale.h index 83f30df18fc7e5967d727ed8ce275d63e1cb29e0..d47d082ccdc6b41cf43860495e43076c17b13ac3 100644 --- a/lite/backends/fpga/KD/llapi/bias_scale.h +++ b/lite/backends/fpga/KD/llapi/bias_scale.h @@ -19,7 +19,7 @@ namespace zynqmp { namespace bias_scale { void align_element(float** data_in, int num_per_div_before_alignment, int num); -void interleave(float** data_in, int num_after_alignment); +size_t interleave(float** data_in, int num_after_alignment); void format_bias_scale_array(float** bias_scale_array, int element_num_per_division, int num); diff --git a/lite/backends/fpga/KD/llapi/filter.cpp b/lite/backends/fpga/KD/llapi/filter.cpp old mode 100644 new mode 100755 index 0e41a204a854b0b57e1a8c98fb3cc8d5224c807c..da81565cf5ca152a54b6cc1514cb660589428439 --- a/lite/backends/fpga/KD/llapi/filter.cpp +++ b/lite/backends/fpga/KD/llapi/filter.cpp @@ -15,6 +15,8 @@ limitations under the License. 
*/ #include "lite/backends/fpga/KD/llapi/filter.h" #include #include +#include +#include #include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" @@ -23,11 +25,41 @@ namespace zynqmp { namespace filter { static int FILTER_SIZE = 2048; +static int COLUMN = 4; + +void saveToFile(std::string name, void* data_in, int size) { + std::ofstream ofs; + ofs.open(name); + + int8_t* data = static_cast(data_in); + for (int i = 0; i < size; i++) { + float value = data[i]; + ofs << value << std::endl; + } + ofs.close(); +} + +void saveFloatToFile(std::string name, float* data_in, int size) { + std::ofstream ofs; + ofs.open(name); + + for (int i = 0; i < size; i++) { + float value = data_in[i]; + ofs << value << std::endl; + } + ofs.close(); +} void set_filter_capacity(uint32_t cap) { FILTER_SIZE = cap; } +void set_colunm(uint32_t column) { COLUMN = column; } + +// replace zynqmp_api.h #define FILTER_NUM_ALIGNMENT +int get_filter_num_alignment() { return COLUMN * 4; } + int calc_division_capacity(int chw) { - int n = FILTER_SIZE / ((chw + 15) / 16) * 32; + int filter_num_alignment = get_filter_num_alignment(); + int n = FILTER_SIZE / ((chw + 15) / 16) * filter_num_alignment; return n < FILTER_SIZE ? n : FILTER_SIZE; } @@ -52,28 +84,36 @@ int calc_num_per_div(int num, int group_num, int division_capacity) { } } -void convert_to_hwc( - char **data_in, int num, int channel, int height, int width) { - char *tmp = *data_in; +int calc_pack_num(int num_per_group, int group, int division_capacity) { + auto n = 1; + while ((num_per_group * (group + n - 1) / n) > division_capacity) { + n++; + } + return (n); +} + +void convert_to_hwc(int8_t* chw_data, + int8_t* hwc_data, + int num, + int channel, + int height, + int width) { int chw = channel * height * width; - char *data_tmp = (char *)fpga_malloc(chw * num * sizeof(char)); // NOLINT + int wc = width * channel; + int index = 0; for (int n = 0; n < num; n++) { - int64_t amount_per_row = width * channel; for (int c = 0; c < channel; c++) { for (int h = 0; h < height; h++) { - int64_t offset_height = h * amount_per_row; for (int w = 0; w < width; w++) { - *(data_tmp + n * chw + offset_height + w * channel + c) = - *((*data_in)++); + hwc_data[n * chw + h * wc + w * channel + c] = chw_data[index]; + index++; } } } } - *data_in = data_tmp; - fpga_free(tmp); } -float find_max(float *data_in, int data_size) { +float find_max(float* data_in, int data_size) { float max = 0.0; for (int i = 0; i < data_size; ++i) { float value = data_in[i]; @@ -83,166 +123,178 @@ float find_max(float *data_in, int data_size) { return max; } -signed char float_to_int8(float fdata) { +int8_t float_to_int8(float fdata) { if (fdata < 0.0) { fdata -= 0.5; } else { fdata += 0.5; } - return (signed char)fdata; + return (int8_t)fdata; } -void quantize(float **data_in, int data_size, float max) { - float *tmp = *data_in; +void quantize(float* src, int8_t* dst, int len, float max) { float fix_range = 127; float scale = fix_range / max; - - signed char *tmp_data = (signed char *)fpga_malloc(data_size * sizeof(char)); - for (int i = 0; i < data_size; i++) { - tmp_data[i] = float_to_int8( - (*data_in)[i] * scale); // (signed char)((*data_in)[i] * scale); + for (size_t i = 0; i < len; i++) { + dst[i] = float_to_int8(src[i] * scale); } - *data_in = (float *)tmp_data; // NOLINT - fpga_free(tmp); } -void align_element(char **data_in, int num, int chw) { - int j = 0; +bool should_align_chw(int chw) { int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); - if 
(align_chw != chw) { - char *tmp = *data_in; - char *data_tmp = - (char *)fpga_malloc(num * align_chw * sizeof(char)); // NOLINT - - memset(data_tmp, 0, num * align_chw); - for (j = 0; j < num; j++) { - memcpy(data_tmp + j * align_chw, (*data_in) + j * chw, chw); - } - *data_in = data_tmp; - fpga_free(tmp); + return align_chw != chw; +} + +void align_chw(int8_t* src, int8_t* dst, int num, int chw) { + int aligned_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + memset(dst, 0, num * aligned_chw); + for (int j = 0; j < num; j++) { + memcpy((dst + j * aligned_chw), (src + j * chw), chw); } } -void align_num(char **data_in, +void align_num(int8_t* src, + int8_t* dst, int num_per_div_before_alignment, int num, - int chw) { - int i = 0; - int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + int align_chw) { + int filter_num_alignment = get_filter_num_alignment(); int num_per_div_after_alignment = - align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); + align_to_x(num_per_div_before_alignment, filter_num_alignment); - char *tmp = *data_in; int div_num = (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; int num_element = div_num * num_per_div_after_alignment * align_chw; - char *data_tmp = (char *)fpga_malloc(num_element * sizeof(char)); // NOLINT - - memset(data_tmp, 0, num_element * sizeof(char)); + memset(dst, 0, num_element * sizeof(int8_t)); + int i = 0; for (i = 0; i < div_num - 1; i++) { - memcpy(data_tmp + num_per_div_after_alignment * align_chw * i, - *data_in + num_per_div_before_alignment * align_chw * i, + memcpy(dst + num_per_div_after_alignment * align_chw * i, + src + num_per_div_before_alignment * align_chw * i, num_per_div_before_alignment * align_chw); } - memcpy(data_tmp + num_per_div_after_alignment * align_chw * i, - *data_in + num_per_div_before_alignment * align_chw * i, + memcpy(dst + num_per_div_after_alignment * align_chw * i, + src + num_per_div_before_alignment * align_chw * i, (num - (div_num - 1) * num_per_div_before_alignment) * align_chw); - - *data_in = data_tmp; - fpga_free(tmp); } -void reorder(char **data_in, int num_after_alignment, int chw) { +void reorder(int8_t* src, int8_t* dst, int num_after_alignment, int chw) { int index = 0; int new_index = 0; - + int filter_num_alignment = get_filter_num_alignment(); int chw_align = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); - - char *data_tmp = - (char *)fpga_malloc(chw_align * num_after_alignment * // NOLINT - sizeof(char)); - char *tmp = *data_in; for (index = 0; index < num_after_alignment; index++) { - new_index = index / 32 * 32 + (index % 16 / 4 * 8) + (index % 16 % 4) + - (index / 16 % 2 * 4); - memcpy(data_tmp + index * chw_align, - *data_in + new_index * chw_align, - chw_align); + new_index = index / filter_num_alignment * filter_num_alignment + + (index % (filter_num_alignment / 2) / 4 * 8) + + (index % (filter_num_alignment / 2) % 4) + + (index / (filter_num_alignment / 2) % 2 * 4); + memcpy((dst + index * chw_align), (src + new_index * chw_align), chw_align); } - *data_in = data_tmp; - fpga_free(tmp); } -size_t interleave(char **data_in, int num_after_alignment, int chw) { - int i = 0; - int j = 0; - int k = 0; +void interleave(int8_t* src, int8_t* dst, int num_after_alignment, int chw) { int interleave_per_num = 16; - int chw_align = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); - char *data_tmp = - (char *)fpga_malloc(chw_align * num_after_alignment * // NOLINT - sizeof(char)); - char *tmp = *data_in; int interleave_num = chw_align * 2 / interleave_per_num; - for 
(i = 0; i < num_after_alignment; i += 2) { - for (j = 0, k = 0; j < interleave_num; j += 2, k++) { - memcpy(data_tmp + i * chw_align + interleave_per_num * j, - *data_in + i * chw_align + interleave_per_num * k, + for (int i = 0; i < num_after_alignment; i += 2) { + for (int j = 0, k = 0; j < interleave_num; j += 2, k++) { + memcpy(dst + i * chw_align + interleave_per_num * j, + src + i * chw_align + interleave_per_num * k, interleave_per_num); - memcpy(data_tmp + i * chw_align + interleave_per_num * (j + 1), - *data_in + (i + 1) * chw_align + interleave_per_num * k, + memcpy(dst + i * chw_align + interleave_per_num * (j + 1), + src + (i + 1) * chw_align + interleave_per_num * k, interleave_per_num); } } - *data_in = data_tmp; - fpga_free(tmp); - return chw_align * num_after_alignment; } -size_t format_filter(float **data_in, - int num, - int channel, - int height, - int width, - int group_num, - float max) { +int8_t* format_filter(float* data_in, + int& mem_size_a, // NOLINT + int num, + int channel, + int height, + int width, + int group_num, + float max, + std::vector<float>& filter_max) { // NOLINT int data_size = channel * height * width * num; int chw = channel * height * width; int division_capacity = calc_division_capacity(chw); + int filter_num_alignment = get_filter_num_alignment(); int num_per_div_before_alignment = calc_num_per_div(num, group_num, division_capacity); int num_per_div_after_alignment = - align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); + align_to_x(num_per_div_before_alignment, filter_num_alignment); int div_num = (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; int residual = num % num_per_div_before_alignment; int num_after_alignment = num_per_div_after_alignment * ((residual == 0) ? div_num : (div_num - 1)) + - align_to_x(residual, FILTER_NUM_ALIGNMENT); - quantize(data_in, data_size, max); - char **quantize_data = (char **)data_in; // NOLINT - convert_to_hwc(quantize_data, num, channel, height, width); - align_element(quantize_data, num, chw); - if (num_after_alignment != num) { - align_num(quantize_data, num_per_div_before_alignment, num, chw); + align_to_x(residual, filter_num_alignment); + + int8_t* quantized_data = + reinterpret_cast<int8_t*>(fpga_malloc(data_size * sizeof(int8_t))); + + for (int n = 0; n < num; n++) { + float* filter_start = data_in + n * chw; + int8_t* quantized_start = quantized_data + n * chw; + quantize(filter_start, quantized_start, chw, max); + filter_max.push_back(1); } - reorder(quantize_data, num_after_alignment, chw); - size_t mem_size = interleave(quantize_data, num_after_alignment, chw); - fpga_flush(*quantize_data, + int8_t* hwc_data = + reinterpret_cast<int8_t*>(fpga_malloc(data_size * sizeof(int8_t))); + convert_to_hwc(quantized_data, hwc_data, num, channel, height, width); + fpga_free(quantized_data); + + int8_t* temp_data = hwc_data; // NOLINT + int chw_aligned = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + if (should_align_chw(chw)) { + int8_t* hwc_aligned_data = reinterpret_cast<int8_t*>( + fpga_malloc(num * chw_aligned * sizeof(int8_t))); + align_chw(hwc_data, hwc_aligned_data, num, chw); + + temp_data = hwc_aligned_data; + fpga_free(hwc_data); + } + if (num_after_alignment != num) { + int filter_num_alignment = get_filter_num_alignment(); + int num_per_div_after_alignment = + align_to_x(num_per_div_before_alignment, filter_num_alignment); + + int num_element = div_num * num_per_div_after_alignment * chw_aligned; + int8_t* num_aligned_data = + reinterpret_cast<int8_t*>(fpga_malloc(num_element * sizeof(int8_t))); +
align_num(temp_data, + num_aligned_data, + num_per_div_before_alignment, + num, + chw_aligned); + + fpga_free(temp_data); + temp_data = num_aligned_data; + } + int8_t* aligned_data = + reinterpret_cast<int8_t*>(fpga_malloc(num_after_alignment * chw_aligned)); + reorder(temp_data, aligned_data, num_after_alignment, chw); + fpga_free(temp_data); + int8_t* interleaved_data = + reinterpret_cast<int8_t*>(fpga_malloc(num_after_alignment * chw_aligned)); + interleave(aligned_data, interleaved_data, num_after_alignment, chw); + fpga_free(aligned_data); + fpga_flush(interleaved_data, align_to_x(chw, FILTER_ELEMENT_ALIGNMENT) * num_after_alignment * sizeof(char)); - return mem_size; + mem_size_a = num_after_alignment * chw_aligned; + return interleaved_data; } -void convert_to_hwn(int16_t **data_in, int num, int height, int width) { - int16_t *tmp = *data_in; - int16_t *data_tmp = - (int16_t *)fpga_malloc(height * width * num * sizeof(int16_t)); // NOLINT +void convert_to_hwn(int16_t** data_in, int num, int height, int width) { + int16_t* tmp = *data_in; + int16_t* data_tmp = + (int16_t*)fpga_malloc(height * width * num * sizeof(int16_t)); // NOLINT for (int n = 0; n < num; n++) { for (int h = 0; h < height; h++) { for (int w = 0; w < width; w++) { @@ -254,16 +306,16 @@ void convert_to_hwn(int16_t **data_in, int num, int height, int width) { fpga_free(tmp); } -size_t align_element_n(int16_t **data_in, int num, int height, int width) { +size_t align_element_n(int16_t** data_in, int num, int height, int width) { int unalign_n = num; int align_n = align_to_x(num, FILTER_ELEMENT_ALIGNMENT); int num_element = height * width * align_n; if (unalign_n != align_n) { - int16_t *tmp = *data_in; + int16_t* tmp = *data_in; int num_element = height * width * align_n; - int16_t *data_tmp = - (int16_t *)fpga_malloc(num_element * sizeof(int16_t)); // NOLINT + int16_t* data_tmp = + (int16_t*)fpga_malloc(num_element * sizeof(int16_t)); // NOLINT memset(data_tmp, 0, num_element * sizeof(int16_t)); for (int h = 0; h < height; h++) { @@ -276,17 +328,37 @@ size_t align_element_n(int16_t **data_in, int num, int height, int width) { } } *data_in = data_tmp; - free(tmp); + fpga_free(tmp); } return num_element * sizeof(int16_t); } +void to_fp16(float* src, + float16* dst, + int num, + int height, + int width, + float* scale_ptr) { + int size = num * height * width; + for (int n = 0; n < num; n++) { + float scale_val = scale_ptr[n]; + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + int index = n * height * width + h * width + w; + float value = src[index] * scale_val; + dst[index] = float_to_half(value); + } + } + } + fpga_flush(dst, size * sizeof(int16_t)); +} + void quantize_to_fp16( - float **data_in, int num, int height, int width, float *scale_ptr) { - float *tmp = *data_in; + float** data_in, int num, int height, int width, float* scale_ptr) { + float* tmp = *data_in; int size = num * height * width; - float16 *tmp_data = (float16 *)fpga_malloc(size * sizeof(float16)); // NOLINT + float16* tmp_data = (float16*)fpga_malloc(size * sizeof(float16)); // NOLINT for (int n = 0; n < num; n++) { float scale_val = scale_ptr[n]; for (int h = 0; h < height; h++) { @@ -298,13 +370,14 @@ void quantize_to_fp16( } } fpga_flush(tmp_data, size * sizeof(int16_t)); - *data_in = (float *)tmp_data; // NOLINT + *data_in = (float*)tmp_data; // NOLINT fpga_free(tmp); } size_t format_dwconv_filter( - float **data_in, int num, int height, int width, float *scale_ptr) { + float** data_in, int num, int height, int width, float* scale_ptr) {
 size_t format_dwconv_filter(
-    float **data_in, int num, int height, int width, float *scale_ptr) {
+    float** data_in, int num, int height, int width, float* scale_ptr) {
   quantize_to_fp16(data_in, num, height, width, scale_ptr);
-  int16_t **quantize_data = (int16_t **)data_in;  // NOLINT
+  int16_t** quantize_data = reinterpret_cast<int16_t**>(data_in);
+  convert_to_hwn(quantize_data, num, height, width);
   size_t size = align_element_n(quantize_data, num, height, width);
   fpga_flush(*quantize_data,
diff --git a/lite/backends/fpga/KD/llapi/filter.h b/lite/backends/fpga/KD/llapi/filter.h
index 7d9c6c2e015250cbcba2d1dba71b7c1f3554d9f0..42d98e74923e116240b145c87b3dc5cfa0210f8d 100644
--- a/lite/backends/fpga/KD/llapi/filter.h
+++ b/lite/backends/fpga/KD/llapi/filter.h
@@ -18,38 +18,36 @@ limitations under the License. */

 #include
 #include
+#include <vector>
+
 namespace paddle {
 namespace zynqmp {
 namespace filter {

 void set_filter_capacity(uint32_t cap);
+void set_colunm(uint32_t column);
+int get_filter_num_alignment();
 int calc_division_capacity(int chw);
 int calc_split_num(int num, int division_capacity);
 int calc_division_number(int num, int group_num, int division_capacity);
 int calc_num_per_div(int num, int group_num, int division_capacity);
-void convert_to_hwc(
-    char** data_in, int num, int channel, int height, int width);
+int calc_pack_num(int num_per_group, int group, int division_capacity);
+
 float find_max(float* data_in, int data_size);
-void quantize(float** data_in, int data_size, float max);
-void align_element(char** data_in, int num, int chw);
-void align_num(char** data_in,
-               int num_per_div_before_alignment,
-               int num,
-               int chw);
-void reorder(char** data_in, int num_after_alignment, int chw);
-size_t interleave(char** data_in, int num_after_alignment, int chw);
-size_t format_filter(float** data_in,
-                     int num,
-                     int channel,
-                     int height,
-                     int width,
-                     int group_num,
-                     float max);
+int8_t* format_filter(float* data_in,
+                      int& mem_size,  // NOLINT
+                      int num,
+                      int channel,
+                      int height,
+                      int width,
+                      int group_num,
+                      float max,
+                      std::vector<float>& filter_max);  // NOLINT

 void convert_to_hwn(int16_t** data_in, int num, int height, int width);
 size_t align_element_n(int16_t** data_in, int num, int height, int width);
-void quantize_to_fp16(
-    float** data_in, int num, int height, int width, float* scale_ptr);
+// void quantize_to_fp16(float** data_in, int num, int height, int width,
+//                       float* scale_ptr);
 size_t format_dwconv_filter(
     float** data_in, int num, int height, int width, float* scale_ptr);
diff --git a/lite/backends/fpga/KD/llapi/zynqmp_api.cpp b/lite/backends/fpga/KD/llapi/zynqmp_api.cpp
index 1f1226ead3d4e9b50100f4de574104a5d6f777b2..bcbf2b98f487aea3c6516fa6369e70d11be97ffc 100644
--- a/lite/backends/fpga/KD/llapi/zynqmp_api.cpp
+++ b/lite/backends/fpga/KD/llapi/zynqmp_api.cpp
@@ -23,13 +23,12 @@ limitations under the License.
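Note: filter.h above replaces the in-place format_filter with one that returns a freshly allocated int8 buffer, reports its size through the mem_size reference, and fills filter_max with per-split maxima. A hedged usage sketch against that declaration (the shape values and the quantize_weights wrapper are illustrative, not part of the patch):

    #include <cstdint>
    #include <vector>
    #include "lite/backends/fpga/KD/llapi/filter.h"

    // weights holds num * channel * height * width floats laid out as NCHW.
    int8_t* quantize_weights(float* weights, int num, int channel, int h, int w) {
      namespace f = paddle::zynqmp::filter;
      float max_value = f::find_max(weights, num * channel * h * w);
      int mem_size = 0;
      std::vector<float> filter_max;  // per-split maxima, filled by the callee
      int8_t* formatted = f::format_filter(
          weights, mem_size, num, channel, h, w, /*group_num=*/1, max_value,
          filter_max);
      // The caller owns the returned buffer (mem_size bytes); release it with
      // fpga_free() once it has been handed to the device.
      return formatted;
    }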
*/ #include #include -#include "lite/backends/fpga/KD/llapi/config.h" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" namespace paddle { namespace zynqmp { -#define PADDLE_LITE_OS_LINUX +#define PADDLE_MOBILE_OS_LINUX static int fd = -1; static const char *device_path = "/dev/fpgadrv0"; @@ -39,14 +38,10 @@ static size_t memory_size_max = 0; static size_t memory_size = 0; static inline int do_ioctl(uint64_t req, const void *arg) { - int ret = -1; -#ifdef PADDLE_LITE_OS_LINUX - ret = ioctl(fd, req, arg); - if (ret != 0) { - throw - 1; - } +#ifdef PADDLE_MOBILE_OS_LINUX + return ioctl(fd, req, arg); #else - return ret; + return -1; #endif } @@ -66,15 +61,33 @@ void reset_device() { // memory management; void *fpga_malloc(size_t size) { -#ifdef PADDLE_LITE_OS_LINUX +#ifdef PADDLE_MOBILE_OS_LINUX + void *ptr = reinterpret_cast( mmap64(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0)); - if (ptr == NULL) { + if (ptr == MAP_FAILED) { std::cout << "not enough memory !"; exit(-1); } + if (errno == ENOMEM) { + std::cout << "mmap failed with not enough memory !"; + exit(-1); + } + if (errno == EINVAL) { + std::cout << "mmap failed with invalid arguments ! (size=" << size << ")" + << std::endl; + exit(-1); + } + if (ptr == NULL) { + std::cout << "NULL returned, errno=" << errno + << ", mmap failed with other errors other than memory usage !" + << std::endl; + exit(-1); + } + memory_map.insert(std::make_pair(ptr, size)); memory_size += size; + if (memory_size > memory_size_max) { memory_size_max = memory_size; } @@ -90,7 +103,7 @@ size_t fpga_get_memory_size_max() { return memory_size_max; } size_t fpga_diagnose_memory(int detailed) { size_t total = 0; - auto iter = memory_map.begin(); // std::map::iterator + auto iter = memory_map.begin(); while (iter != memory_map.end()) { total += iter->second; iter++; @@ -100,7 +113,7 @@ size_t fpga_diagnose_memory(int detailed) { void fpga_free(void *ptr) { size_t size = 0; - auto iter = memory_map.find(ptr); // std::map::iterator + auto iter = memory_map.find(ptr); if (iter != memory_map.end()) { size = iter->second; memory_map.erase(iter); @@ -108,8 +121,7 @@ void fpga_free(void *ptr) { memory_size -= size; -#ifdef PADDLE_LITE_OS_LINUX - +#ifdef PADDLE_MOBILE_OS_LINUX munmap(ptr, size); #else free(ptr); @@ -150,6 +162,11 @@ void fpga_copy(void *dest, const void *src, size_t num) { memcpy(dest, src, num); } +int fpga_reset() { + struct FpgaResetArgs args; + return do_ioctl(IOCTL_FPGA_RESET, &args); +} + int ioctl_conv(const struct ConvArgs &args) { return do_ioctl(IOCTL_CONFIG_CONV, &args); } @@ -166,7 +183,6 @@ int compute_fpga_conv(const struct SplitConvArgs &args) { } if (split_num > 1) { - std::cout << "Split num > 1 !!!!!!!!!!!!!!!!!!" 
<< std::endl; exit(-1); } return ret; @@ -186,6 +202,7 @@ int get_device_info(const struct DeviceInfo &args) { } int perform_bypass(const struct BypassArgs &args) { + int ret = -1; int size = args.image.channels * args.image.width * args.image.height; int max_size = 1 << 21; @@ -213,7 +230,7 @@ int perform_bypass(const struct BypassArgs &args) { reinterpret_cast(input_address + i * max_size * type_size); bypassArgs.output.address = reinterpret_cast(output_address + i * max_size * out_type_size); - int ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); + ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); scale = std::max(scale, scales[0]); if (ret != 0) { @@ -222,13 +239,16 @@ int perform_bypass(const struct BypassArgs &args) { } int remainder = size - max_size * count; - bypassArgs.image.channels = remainder; - bypassArgs.image.address = - reinterpret_cast(input_address + count * max_size * type_size); - bypassArgs.output.address = reinterpret_cast( - output_address + count * max_size * out_type_size); - int ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); - scale = std::max(scale, scales[0]); + if (remainder > 0) { + bypassArgs.image.channels = remainder; + bypassArgs.image.address = + reinterpret_cast(input_address + count * max_size * type_size); + bypassArgs.output.address = reinterpret_cast( + output_address + count * max_size * out_type_size); + ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs); + scale = std::max(scale, scales[0]); + } + args.output.scale_address[0] = scale; args.output.scale_address[1] = 1.0f / scale; return ret; @@ -261,28 +281,13 @@ int compute_fpga_scale(const struct ScaleArgs &args) { } int compute_fpga_dwconv(const struct DWconvArgs &args) { -#ifdef ENABLE_DEBUG - std::cout << "======Compute Basic Conv======"; - std::cout << " relu_enabled:" << args.relu_enabled - << " filter_address:" << args.filter_address; - std::cout << " image_address:" << args.image.address - << " image_scale_address:" << args.image.scale_address - << " image_channels:" << args.image.channels - << " image_height:" << args.image.height - << " image_width:" << args.image.width - << " pad_height:" << args.image.pad_height - << " pad_width:" << args.image.pad_width; - std::cout << " kernel_height:" << args.kernel.height - << " kernel_width:" << args.kernel.width - << " stride_h:" << args.kernel.stride_h - << " stride_w:" << args.kernel.stride_w; - std::cout << " out_address:" << args.output.address - << " out_scale_address:" << args.output.scale_address; - -#endif return do_ioctl(IOCTL_CONFIG_DWCONV, &args); } +int config_activation(const struct ActiveParamterArgs &args) { + return do_ioctl(IOCTL_CONFIG_ACTIVATION_PARAMETER, &args); +} + int config_inplace(const struct InplaceArgs &args) { return do_ioctl(IOCTL_CONFIG_INPLACE, &args); } diff --git a/lite/backends/fpga/KD/llapi/zynqmp_api.h b/lite/backends/fpga/KD/llapi/zynqmp_api.h index 7d22de95a2272862c6fe781295bdaab7177a92fe..55c2fde079a1ca0ec368870e2bb8f727d870a8f3 100644 --- a/lite/backends/fpga/KD/llapi/zynqmp_api.h +++ b/lite/backends/fpga/KD/llapi/zynqmp_api.h @@ -14,6 +14,9 @@ limitations under the License. 
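Note: perform_bypass above now submits the image in fixed 2^21-element chunks and only fires the trailing ioctl when a non-empty remainder is left; previously the tail call ran unconditionally, even with zero channels. The chunking arithmetic in isolation:

    #include <cstdio>

    int main() {
      int size = 5 << 20;           // channels * width * height, illustrative
      int max_size = 1 << 21;       // per-ioctl element limit
      int count = size / max_size;  // number of full chunks
      int remainder = size - max_size * count;
      for (int i = 0; i < count; i++) {
        printf("chunk %d: offset %d, %d elements\n", i, i * max_size, max_size);
      }
      if (remainder > 0) {          // the guard this patch adds
        printf("tail: offset %d, %d elements\n", count * max_size, remainder);
      }
      return 0;
    }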
*/ #pragma once +#ifndef PADDLE_LITE_SRC_FPGA_KD_ZYNQMP_API_H +#define PADDLE_LITE_SRC_FPGA_KD_ZYNQMP_API_H + #include #include #include @@ -25,7 +28,6 @@ namespace zynqmp { typedef int16_t half; #define IMAGE_ALIGNMENT 16 // Aligned to 16 -#define FILTER_NUM_ALIGNMENT 32 // Filter number aligned to 32 #define FILTER_ELEMENT_ALIGNMENT 16 // Filter element number aligned to 16 #define BS_NUM_ALIGNMENT 8 #define BIAS_NUM_ALIGNMENT 16 @@ -40,15 +42,19 @@ enum DLayoutType { LAYOUT_HWC = 0, }; -struct VersionArgs { - void* buffer; +enum ActiveType { + TYPE_NONE = 0, + TYPE_RELU = 1, + TYPE_RELU6 = 2, + TYPE_LEAKY_RELU = 3, + TYPE_SIGMOID = 4, }; struct DeviceInfo { uint32_t filter_cap; uint32_t version; uint16_t device_type; - uint32_t reserved0; + uint32_t colunm; uint32_t reserved1; uint32_t reserved2; uint32_t reserved3; @@ -57,6 +63,11 @@ struct DeviceInfo { uint32_t reserved6; }; +struct VersionArgs { + void* buffer; + size_t size; +}; + struct MemoryCopyArgs { void* src; void* dest; @@ -68,7 +79,9 @@ struct MemoryCacheArgs { size_t size; }; -struct MemoryBarrierArgs {}; +struct MemoryBarrierArgs { + uint16_t dummy; +}; struct BNArgs { bool enabled; @@ -108,6 +121,7 @@ struct ConvArgs { void* filter_scale_address; uint32_t filter_num; uint32_t group_num; + uint32_t dilation; struct KernelArgs kernel; struct ImageInputArgs image; // input image; @@ -199,9 +213,16 @@ struct NormalizeParameterArgs { uint32_t hight_width; }; +struct ActiveParamterArgs { + ActiveType type; + uint16_t leaky_relu_factor; +}; + struct InplaceArgs { bool leaky_relu_enable; bool relu_enable; + bool sigmoid_enable; + bool relu6_enable; bool power_enable; bool normalize_enable; }; @@ -216,7 +237,9 @@ struct FpgaRegReadArgs { uint64_t value; }; -struct FpgaResetArgs {}; +struct FpgaResetArgs { + uint32_t val; +}; #define IOCTL_FPGA_MAGIC (('F' + 'P' + 'G' + 'A') / 4) @@ -248,6 +271,8 @@ struct FpgaResetArgs {}; _IOW(IOCTL_FPGA_MAGIC, 41, struct PowerParameterArgs) #define IOCTL_CONFIG_NORMALIZE_PARAMETER \ _IOW(IOCTL_FPGA_MAGIC, 42, struct NormalizeParameterArgs) +#define IOCTL_CONFIG_ACTIVATION_PARAMETER \ + _IOW(IOCTL_FPGA_MAGIC, 43, struct ActiveParamterArgs) #define IOCTL_FPGA_REG_READ _IOW(IOCTL_FPGA_MAGIC, 50, struct FpgaRegReadArgs) #define IOCTL_FPGA_REG_WRITE _IOW(IOCTL_FPGA_MAGIC, 51, struct FpgaRegWriteArgs) #define IOCTL_FPGA_RESET _IOW(IOCTL_FPGA_MAGIC, 52, struct FpgaResetArgs) @@ -331,6 +356,7 @@ int compute_fpga_scale(const struct ScaleArgs& args); int compute_fpga_concat(const struct ConcatArgs& args); int compute_fpga_resize(const struct ResizeArgs& args); +int config_activation(const struct ActiveParamterArgs& args); int config_power(const struct PowerArgs& args); int compute_fpga_dwconv(const struct DWconvArgs& args); int config_norm_param(const struct NormalizeParameterArgs& args); @@ -341,7 +367,11 @@ int config_inplace(const struct InplaceArgs& args); int flush_cache(void* addr, int size); int invalidate_cache(void* addr, int size); +int fpga_reset(); + int16_t fp32_2_fp16(float fp32_num); float fp16_2_fp32(int16_t fp16_num); } // namespace zynqmp } // namespace paddle + +#endif // PADDLE_LITE_SRC_FPGA_KD_ZYNQMP_API_H diff --git a/lite/backends/fpga/KD/pe.hpp b/lite/backends/fpga/KD/pe.hpp index d1dc3c4caa18cbfeba74fac26cca9e19230e2c21..2796124341012574dc719ae9f30633d1d9524680 100644 --- a/lite/backends/fpga/KD/pe.hpp +++ b/lite/backends/fpga/KD/pe.hpp @@ -32,6 +32,5 @@ class PE { virtual ~PE() {} }; - } // namespace zynqmp } // namespace paddle diff --git a/lite/backends/fpga/KD/pe_params.hpp 
b/lite/backends/fpga/KD/pe_params.hpp
index 709f04d399793c6f21c34fc1265f7ed8b5818314..42ec32957e5884aaae3cc96f46060de114b44ead 100644
--- a/lite/backends/fpga/KD/pe_params.hpp
+++ b/lite/backends/fpga/KD/pe_params.hpp
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once

 #include
+#include <string>
 #include

 #include "lite/backends/fpga/KD/llapi/zynqmp_api.h"
@@ -26,10 +27,16 @@ namespace zynqmp {
 struct ReLUParam {
  public:
   bool enabled = false;
+  float leaky_relu_factor = 0.0f;
+};
+
+struct ActiveParam {
+  enum ActiveType type = TYPE_NONE;
+  float leaky_relu_factor;
 };

 struct PEParam {
-  ReLUParam relu;
+  ActiveParam activeParam;
 };

 struct InputParam : PEParam {
@@ -133,6 +140,13 @@ struct ElementwiseAddParam : PEParam {
   EWAddArgs ewargs;
 };

+struct ElementwiseMulParam : PEParam {
+ public:
+  Tensor* input_x;
+  Tensor* input_y = nullptr;
+  Tensor* output = nullptr;
+};
+
 struct FullyConnectedParam : PEParam {
  public:
   Tensor* input = nullptr;
@@ -197,6 +211,17 @@ struct PriorBoxParam : PEParam {
   float offset;
 };

+struct YoloBoxParam : PEParam {
+  Tensor* input;
+  Tensor* imgSize;
+  Tensor* outputBoxes;
+  Tensor* outputScores;
+  int downsampleRatio;
+  std::vector<int> anchors;
+  int classNum;
+  float confThresh;
+};
+
 struct ScaleParam : PEParam {
  public:
   Tensor* input = nullptr;
@@ -229,5 +254,24 @@ struct CropParam : PEParam {
   std::vector<int> offsets;
   std::vector<int> shape;
 };
+
+struct GRUParam : PEParam {
+ public:
+  Tensor* input = nullptr;
+  Tensor* h0 = nullptr;
+  Tensor* weight = nullptr;
+  Tensor* bias = nullptr;
+
+  Tensor* batch_gate = nullptr;
+  Tensor* batch_reset_hidden_prev = nullptr;
+  Tensor* batch_hidden = nullptr;
+  Tensor* hidden = nullptr;
+
+  std::string gate_activation = "sigmoid";
+  std::string activation = "tanh";
+  bool is_reverse = false;
+  bool origin_mode = false;
+};
+
 }  // namespace zynqmp
 }  // namespace paddle
diff --git a/lite/backends/fpga/KD/pes/conv_pe.hpp b/lite/backends/fpga/KD/pes/conv_pe.hpp
index e897f82280fa57f904bd7c749e371d8ec9219b51..b4eac2c41e138cab19197ccb8ab89681a69ec6fe 100644
--- a/lite/backends/fpga/KD/pes/conv_pe.hpp
+++ b/lite/backends/fpga/KD/pes/conv_pe.hpp
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once

 #include
+#include
 #include

 #include "lite/backends/fpga/KD/pe.hpp"
@@ -24,6 +25,7 @@ limitations under the License.
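Note: with ReLUParam folded into the new ActiveParam, every PE now reads one activation descriptor instead of a bare relu flag. A sketch of how kernel code might populate it from a framework-level activation name (this helper is illustrative and not part of the patch):

    #include <string>
    #include "lite/backends/fpga/KD/pe_params.hpp"

    void set_activation(paddle::zynqmp::PEParam& param,
                        const std::string& act, float slope) {
      using namespace paddle::zynqmp;
      if (act == "relu") {
        param.activeParam.type = TYPE_RELU;
      } else if (act == "relu6") {
        param.activeParam.type = TYPE_RELU6;
      } else if (act == "leaky_relu") {
        param.activeParam.type = TYPE_LEAKY_RELU;
        param.activeParam.leaky_relu_factor = slope;  // only type with a factor
      } else if (act == "sigmoid") {
        param.activeParam.type = TYPE_SIGMOID;
      } else {
        param.activeParam.type = TYPE_NONE;
      }
    }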
*/ #include "lite/backends/fpga/KD/pes/conv_process.hpp" #include "lite/backends/fpga/KD/pes/elementwise_add_pe.hpp" #include "lite/backends/fpga/KD/pes/scale_pe.hpp" +#include "lite/backends/fpga/KD/pes/split_pe.hpp" namespace paddle { namespace zynqmp { @@ -40,6 +42,8 @@ class ConvPE : public PE { void apply() { split_axis = fill_split_arg(param_); + split_channel = param_.groups != 1 && param_.splitParams().size() > 1; + if (split_axis == 0 && param_.splitParams().size() > 1) { ConcatParam& concat_param = concatPE_.param(); for (auto conv_param : param_.splitParams()) { @@ -49,6 +53,28 @@ class ConvPE : public PE { concatPE_.init(); concatPE_.apply(); } + + if (split_channel) { + SplitParam& split_param = splitPE_.param(); + split_param.input = param_.input; + for (auto conv_param : param_.splitParams()) { + split_param.outputs.push_back(&conv_param->input); + } + splitPE_.init(); + splitPE_.apply(); + } + + if (DLEngine::get_instance().isZU3() && + param_.input->shape().dimSize() == 4 && + param_.input->shape().width() == 1 && + param_.input->shape().channel() >= 2048) { + use_cpu_ = true; + } + if (!use_cpu_) { + // param_.filter->releaseData(); + } + + // exit(-1); } void cpu_compute() { Tensor* input = param_.input; @@ -59,6 +85,7 @@ class ConvPE : public PE { Tensor float_output; float* image_addr = float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); + // float16* data_out = output->data(); float* out = float_output.mutableData(FP32, output->shape()); int out_channel = output->shape().channel(); @@ -66,13 +93,21 @@ class ConvPE : public PE { float* filter_data = param_.filter->data(); float* mi = new float[in_channel]; - for (int i = 0; i < out_channel; i++) { float* image = image_addr; float* filter_ptr = filter_data + i * in_channel; float* out_ptr = mi; #pragma omp parallel for for (int j = 0; j < in_channel; j++) { + // float32x4_t x0 = vld1q_f32(image); + // float32x4_t x1 = vld1q_f32(filter_ptr); + + // float32x4_t r = vmulq_f32(x0, x1); + + // vst1q_f32(out_ptr, r); + // image += 4; + // filter_ptr += 4; + // out_ptr += 4; float value = image_addr[j] * filter_ptr[j]; mi[j] = value; } @@ -89,49 +124,104 @@ class ConvPE : public PE { } bool dispatch() { - inplace_.relu_enable = param_.relu.enabled; - inplace_.power_enable = false; - inplace_.normalize_enable = false; + fpga_reset(); + if (use_cpu_) { + cpu_compute(); + return true; + } - if (param_.relu.enabled) { - inplace_.relu_enable = param_.relu.enabled; + if (param_.activeParam.type == TYPE_RELU) { + inplace_.relu_enable = true; + } else if (param_.activeParam.type == TYPE_RELU6) { + inplace_.relu6_enable = true; + } else if (param_.activeParam.type == TYPE_SIGMOID) { + inplace_.sigmoid_enable = true; + } else if (param_.activeParam.type == TYPE_LEAKY_RELU) { + inplace_.leaky_relu_enable = true; + } + + if (inplace_.relu_enable || inplace_.leaky_relu_enable || + inplace_.relu6_enable || inplace_.sigmoid_enable) { config_inplace(inplace_); + if (inplace_.leaky_relu_enable) { + activeParamterArgs.type = TYPE_LEAKY_RELU; + activeParamterArgs.leaky_relu_factor = + fp32_2_fp16(param_.activeParam.leaky_relu_factor); + config_activation(activeParamterArgs); + } } std::vector& params = param_.splitParams(); + + if (split_channel) { + // splitPE_.param().input->saveToFile("input_image",true); + splitPE_.dispatch(); + } + int ret = 0; for (auto conv_param : params) { + // conv_param->input.printScale(); + // if (split_channel) { + // conv_param->input.saveToFile("pack_image",true); + // } ret |= 
compute_fpga_conv_basic(conv_param->args); } - if (param_.relu.enabled) { + if (inplace_.relu_enable || inplace_.leaky_relu_enable || + inplace_.relu6_enable || inplace_.sigmoid_enable) { inplace_.relu_enable = false; + inplace_.leaky_relu_enable = false; + inplace_.relu6_enable = false; + inplace_.sigmoid_enable = false; config_inplace(inplace_); + + if (inplace_.leaky_relu_enable) { + activeParamterArgs.type = TYPE_LEAKY_RELU; + activeParamterArgs.leaky_relu_factor = fp32_2_fp16(0); + config_activation(activeParamterArgs); + } } size_t size = params.size(); if (split_axis == 0 && ret == 0 && size > 1) { + // std::cout << "concat size:" << size << std::endl; concatPE_.dispatch(); } if (split_axis == 1 && ret == 0 && size > 1) { + // for (int n = 0; n < size - 1; n++) { ElementwiseAddParam& add_param = addPE_.param(); add_param.inputs = {¶ms[0]->output, ¶ms[1]->output}; add_param.output = param_.output; addPE_.init(); addPE_.apply(); addPE_.dispatch(); + + // param_.output->printScale(); + + // params[0]->input.saveToFile("conv_1.txt"); + // params[1]->input.saveToFile("conv_2.txt"); + + // params[0]->output.saveToFile("ew_o1.txt"); + // params[1]->output.saveToFile("ew_o2.txt"); + // std::cout << "\n ================== EW ================== \n"; + // } } + return ret == 0; } ConvParam& param() { return param_; } private: + bool use_cpu_ = false; + bool split_channel = false; ConvParam param_; ConcatPE concatPE_; + SplitPE splitPE_; ElementwiseAddPE addPE_; int split_axis = 0; InplaceArgs inplace_ = {0}; + ActiveParamterArgs activeParamterArgs; }; } // namespace zynqmp diff --git a/lite/backends/fpga/KD/pes/conv_process.hpp b/lite/backends/fpga/KD/pes/conv_process.hpp old mode 100644 new mode 100755 index fd17218d06f050df3dc935bdde0a320e52b56a40..cea22e0edc647b3bf4f0ac15e43121b5d8926154 --- a/lite/backends/fpga/KD/pes/conv_process.hpp +++ b/lite/backends/fpga/KD/pes/conv_process.hpp @@ -14,6 +14,9 @@ limitations under the License. 
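Note: ConvPE::dispatch above raises exactly one InplaceArgs flag per ActiveType before the conv ioctls and clears all of them afterwards so the hardware state does not leak into later ops. The same set/clear pattern factored into helpers (a refactoring sketch, not code from the patch):

    #include "lite/backends/fpga/KD/llapi/zynqmp_api.h"

    namespace zq = paddle::zynqmp;

    // Raise the flag matching `type` and push it to the device if any flag is set.
    void enable_inplace(zq::InplaceArgs& inplace, zq::ActiveType type) {
      inplace.relu_enable = (type == zq::TYPE_RELU);
      inplace.relu6_enable = (type == zq::TYPE_RELU6);
      inplace.sigmoid_enable = (type == zq::TYPE_SIGMOID);
      inplace.leaky_relu_enable = (type == zq::TYPE_LEAKY_RELU);
      if (inplace.relu_enable || inplace.relu6_enable ||
          inplace.sigmoid_enable || inplace.leaky_relu_enable) {
        zq::config_inplace(inplace);
      }
    }

    // Drop every activation flag again once the compute ioctls are done.
    void disable_inplace(zq::InplaceArgs& inplace) {
      inplace.relu_enable = false;
      inplace.relu6_enable = false;
      inplace.sigmoid_enable = false;
      inplace.leaky_relu_enable = false;
      zq::config_inplace(inplace);
    }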
*/ #pragma once +#ifndef conv_process_hpp +#define conv_process_hpp + #include #include #include @@ -45,7 +48,19 @@ inline int get_split_num(Tensor* filter) { filter->shape().width(); auto num = filter->shape().num(); int div_capacity = filter::calc_division_capacity(chw); - return filter::calc_split_num(num, div_capacity); + int filter_num_alignment = filter::get_filter_num_alignment(); + int aligned_num = align_to_x(num, filter_num_alignment); + return filter::calc_split_num(aligned_num, div_capacity); +} + +inline int get_pack_num(Tensor* filter, int group_num) { + auto chw = filter->shape().channel() * filter->shape().height() * + filter->shape().width(); + auto num = filter->shape().num(); + int div_capacity = filter::calc_division_capacity(chw); + int filter_num_alignment = filter::get_filter_num_alignment(); + int aligned_num_per_group = align_to_x(num / group_num, filter_num_alignment); + return filter::calc_pack_num(aligned_num_per_group, group_num, div_capacity); } inline void fill_scale_bias_const(ConvParam* param_) { @@ -112,6 +127,50 @@ inline void combine_add_bn_params(BatchnormParam* bn, param_->bias()->setDataLocation(CPU); } +inline int gcd_(int a, int b) { + while (b) { + int temp = a; + a = b; + b = temp % b; + } + return a; +} + +inline int lcm_(int a, int b) { return a * b / gcd_(a, b); } + +inline void format_bias_scale_new(Tensor* bias, + Tensor* scale, + Tensor* scale_bias) { + Shape& bias_shape = bias->shape(); + int channel = bias_shape.channel(); + int repeat = 1; + int alignment = 16; + int length = channel; + + if (channel % alignment != 0 || channel < alignment) { + int c_lcm = lcm_(channel, alignment); + repeat = c_lcm / (channel); + } + Shape shape(N, {2 * channel * repeat}); + float16* scale_bias_data = scale_bias->mutableData(FP16, shape); + + float* bias_data_float = bias->data(); + float* scale_data_float = scale->data(); + + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + float16 value_bias = float_to_half(bias_data_float[j]); + scale_bias_data[i * length + j] = value_bias; + } + } + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + float16 value_scale = float_to_half(scale_data_float[j]); + scale_bias_data[i * length + j + length * repeat] = value_scale; + } + } +} + inline void format_scale_bias(Tensor* scale, Tensor* bias, Tensor* filter, @@ -126,41 +185,99 @@ inline void format_scale_bias(Tensor* scale, bias_data = bias->data(); } int channel = filter->shape().num(); - Shape bias_scale_shape(N, {2 * channel}); + int scale_bias_len = align_to_x(channel / group, BS_NUM_ALIGNMENT) * group; + + int c_per_group = channel / group; + int aligned_c_per_group = align_to_x(channel / group, BS_NUM_ALIGNMENT); + + Shape bias_scale_shape(N, {2 * scale_bias_len}); float* bs_data = scale_bias->mutableData(FP32, bias_scale_shape); - for (int i = 0; i < channel; i++) { - float scale_value = scale_data == nullptr ? 1 : scale_data[i]; - float bias_value = bias_data == nullptr ? 
0 : bias_data[i]; - bs_data[i + channel] = scale_value; - bs_data[i] = bias_value; + float* temp_data = + reinterpret_cast(fpga_malloc(2 * scale_bias_len * sizeof(float))); + memset(temp_data, 0, 2 * scale_bias_len * sizeof(float)); + + std::vector scales; + if (scale_data != nullptr) { + for (int i = 0; i < channel; ++i) { + scales.push_back(scale_data[i]); + } + for (int i = 0; i < scale_bias_len - channel; i++) { + scales.push_back(1); + } + } else { + for (int i = 0; i < scale_bias_len; i++) { + scales.push_back(1); + } + } + + for (int i = 0; i < scale_bias_len; ++i) { + temp_data[i + scale_bias_len] = 1; + temp_data[i] = 0; } - int element_num_per_div = get_filter_num_per_div(filter, group); - bias_scale::format_bias_scale_array(&bs_data, element_num_per_div, channel); + for (int g = 0; g < group; g++) { + for (int c = 0; c < c_per_group; c++) { + int src_index = g * c_per_group + c; + int dst_index = g * aligned_c_per_group + c; + float scale_value = scales[src_index]; + float bias_value = bias_data == nullptr ? 0 : bias_data[src_index]; + temp_data[dst_index + scale_bias_len] = scale_value; + temp_data[dst_index] = bias_value; + } + } + + bias_scale::format_bias_scale_array( + &temp_data, scale_bias_len / group, scale_bias_len); + memcpy(bs_data, temp_data, 2 * scale_bias_len * sizeof(float)); } -inline void format_filter(Tensor* filter, Tensor* quantized_filter, int group) { +inline void format_filter(Tensor* filter, + Tensor* quantized_filter, + int group, + std::vector& scales, // NOLINT + float max) { float max_value = find_max(*filter); + // max_value = max; //TODO: global quantization for filter Shape& filter_shape = filter->shape(); + + int mem_size; + std::vector max_values; + int8_t* quantized_data = filter::format_filter(filter->data(), + mem_size, + filter_shape.num(), + filter_shape.channel(), + filter_shape.height(), + filter_shape.width(), + group, + max_value, + max_values); + + float mem_factor = mem_size * 1.0f / filter->shape().numel(); + quantized_filter->setMemScale(mem_factor); + quantized_filter->setAligned(true); - quantized_filter->mutableData(INT8, filter->shape()); + int8_t* src = quantized_filter->mutableData(INT8, filter->shape()); quantized_filter->scale()[0] = max_value / 127.0f; quantized_filter->scale()[1] = 127.0f / max_value; - auto memory_size = filter->shape().memorySize(sizeof(float)); - auto new_data = reinterpret_cast(fpga_malloc(memory_size)); - memcpy(new_data, filter->data(), memory_size); - size_t mem_size = filter::format_filter(&new_data, - filter_shape.num(), - filter_shape.channel(), - filter_shape.height(), - filter_shape.width(), - group, - max_value); - int8_t* src = quantized_filter->mutableData(INT8, filter->shape()); - memcpy(src, new_data, mem_size); - fpga_free(new_data); + memcpy(src, quantized_data, mem_size); quantized_filter->flush(); + fpga_free(quantized_data); + + // for (size_t i = 0; i < max_values.size(); i++) { + // // scales.push_back(max_values[i] / max_value); + // scales.push_back(1.0f); + // } + + // filter->saveToFile("filter.txt"); + // std::ofstream ofs; + // ofs.open("quant.txt"); + // for (int i = 0; i < mem_size; i++) { + // float value = quantized_data[i]; + // ofs << value << std::endl; + // } + // ofs.close(); + // exit(-1); } inline void format_dw_filter(Tensor* filter, @@ -207,10 +324,11 @@ inline void split_filter_num(const ConvParam& c_param) { Tensor* out = param.output; Tensor* filter = param.filter; auto channel = out->shape().channel(); - - int split_num = param.groups == 1 ? 
get_split_num(param.filter) : 1; + int split_num = get_split_num(param.filter); int filter_num_per_div = get_filter_num_per_div(filter, param.groups); + float max = find_max(*filter); + Shape& out_shape = out->shape(); for (int i = 0; i < split_num; i++) { BasicConvParam* conv_param = new BasicConvParam(); @@ -251,17 +369,18 @@ inline void split_filter_num(const ConvParam& c_param) { filter->data() + i * filter_num_per_div * filter_hwc, filter_num * filter_hwc * sizeof(float)); new_filter.flush(); - conv_param->filter.mutableData(FP32, f_shape); - format_filter(&new_filter, &(conv_param->filter), param.groups); - int sb_num = 2 * align_to_x(filter_num, BS_NUM_ALIGNMENT); + std::vector v; // TODO(chonwhite) change variable name; + format_filter(&new_filter, &(conv_param->filter), param.groups, v, max); + conv_param->filter.setDataType(INT8); + Tensor scale; Tensor bias; int chnnnel_start = i * filter_num_per_div; - Shape s_shape(N, {filter_num}); + Shape s_shape(NC, {1, filter_num}); float* scale_data = scale.mutableData(FP32, s_shape); float* bias_data = bias.mutableData(FP32, s_shape); for (int n = 0; n < filter_num; n++) { @@ -270,17 +389,11 @@ inline void split_filter_num(const ConvParam& c_param) { for (int n = 0; n < filter_num; n++) { bias_data[n] = param.bias()->data()[n + chnnnel_start]; } - Shape sb_shape(N, {sb_num}); - format_scale_bias(&scale, - &bias, - &conv_param->filter, - &conv_param->scaleBias, - param.groups); + format_bias_scale_new(&bias, &scale, &conv_param->scaleBias); conv_param->scaleBias.flush(); args.group_num = param.groups; - args.relu_enabled = param.relu.enabled; - args.sb_address = conv_param->scaleBias.data(); + args.sb_address = conv_param->scaleBias.data(); args.kernel.stride_h = param.strides[1]; args.kernel.stride_w = param.strides[0]; args.kernel.height = new_filter.shape().height(); @@ -296,6 +409,137 @@ inline void split_filter_num(const ConvParam& c_param) { args.image.height = input->shape().height(); args.image.pad_width = param.paddings[1]; args.image.pad_height = param.paddings[0]; + args.dilation = param.dilations[0]; + + args.output.address = out_address; + args.output.scale_address = out_scale_address; + param.splitParams().push_back(conv_param); + } +} + +inline void pack_channel_filter(const ConvParam& c_param) { + ConvParam& param = const_cast(c_param); + Tensor* input = param.input; + Tensor* out = param.output; + Tensor* filter = param.filter; + int filter_num_alignment = filter::get_filter_num_alignment(); + auto filter_num = filter->shape().num(); + int pack_num = get_pack_num(param.filter, param.groups); + int group_per_pack = (param.groups + pack_num - 1) / pack_num; + int filter_per_group = filter_num / param.groups; + int filter_per_pack = filter_per_group * group_per_pack; + int channel_per_pack = filter->shape().channel() * group_per_pack; + + float max = find_max(*filter); + + Shape& out_shape = out->shape(); + + for (int i = 0; i < pack_num; i++) { + BasicConvParam* conv_param = new BasicConvParam(); + + conv_param->output.setDataLocation(Device); + conv_param->output.setAligned(true); + + float16* out_address = nullptr; + float* out_scale_address = nullptr; + + float16* input_address = nullptr; + + ConvArgs& args = conv_param->args; + + if (pack_num == 1) { + out_address = out->data(); + out_scale_address = out->scale(); + } + + int new_group = param.groups; + int filter_current_pack = filter->shape().num(); + int channel_current_pack = input->shape().channel(); + + new_group = i == pack_num - 1 + ? 
param.groups - (pack_num - 1) * group_per_pack + : group_per_pack; + filter_current_pack = new_group * filter_per_group; + channel_current_pack = new_group * filter->shape().channel(); + + if (pack_num == 1) { + input_address = input->data(); + } else { + Shape in_shape(NCHW, + {1, + channel_current_pack, + input->shape().height(), + input->shape().width()}); + input_address = conv_param->input.mutableData(FP16, in_shape); + } + + if (pack_num != 1) { + Shape shape( + NHWC, + {1, out_shape.height(), out_shape.width(), filter_current_pack}); + out_address = conv_param->output.mutableData(FP16, shape); + out_scale_address = conv_param->output.scale(); + } + Shape f_shape(NCHW, + {filter_current_pack, + filter->shape().channel(), + filter->shape().height(), + filter->shape().width()}); + + Tensor new_filter; + float* new_filter_data = new_filter.mutableData(FP32, f_shape); + int filter_hwc = filter->shape().height() * filter->shape().width() * + filter->shape().channel(); + + memcpy(new_filter_data, + filter->data() + i * filter_per_pack * filter_hwc, + filter_current_pack * filter_hwc * sizeof(float)); + new_filter.flush(); + conv_param->filter.mutableData(FP32, f_shape); + + float mem_factor = filter_num_alignment / filter_per_pack; + conv_param->filter.setMemScale(mem_factor); + + std::vector v; // TODO(chonwhite) change variable name + format_filter(&new_filter, &(conv_param->filter), new_group, v, max); + conv_param->filter.setDataType(INT8); + + Tensor scale; + Tensor bias; + + int chnnnel_start = i * filter_per_pack; + + Shape s_shape(NC, {1, filter_current_pack}); + float* scale_data = scale.mutableData(FP32, s_shape); + float* bias_data = bias.mutableData(FP32, s_shape); + for (int n = 0; n < filter_current_pack; n++) { + scale_data[n] = param.scale()->data()[n + chnnnel_start]; + } + for (int n = 0; n < filter_current_pack; n++) { + bias_data[n] = param.bias()->data()[n + chnnnel_start]; + } + format_bias_scale_new(&bias, &scale, &conv_param->scaleBias); + conv_param->scaleBias.flush(); + + args.group_num = new_group; + args.sb_address = conv_param->scaleBias.data(); + args.kernel.stride_h = param.strides[1]; + args.kernel.stride_w = param.strides[0]; + args.kernel.height = new_filter.shape().height(); + args.kernel.width = new_filter.shape().width(); + + args.filter_address = conv_param->filter.data(); + args.filter_num = filter_current_pack; + args.filter_scale_address = conv_param->filter.scale(); + args.image.address = input_address; + args.image.scale_address = input->scale(); + args.image.channels = channel_current_pack; + args.image.width = input->shape().width(); + args.image.height = input->shape().height(); + args.image.pad_width = param.paddings[1]; + args.image.pad_height = param.paddings[0]; + args.dilation = param.dilations[0]; + args.output.address = out_address; args.output.scale_address = out_scale_address; param.splitParams().push_back(conv_param); @@ -310,9 +554,11 @@ inline void split_channel(const ConvParam& c_param) { int num = ceil(input->shape().channel() * 1.0f / 2047); int channel = input->shape().channel() / num; - std::cout << "channel::" << channel << "num::" << num << std::endl; + Shape bs_shape(N, {channel}); + float max = 1.0f; + for (int i = 0; i < num; i++) { BasicConvParam* conv_param = new BasicConvParam(); @@ -324,6 +570,7 @@ inline void split_channel(const ConvParam& c_param) { // filter transformation; Shape f_shape(NCHW, {param.filter->shape().num(), channel, 1, 1}); + Tensor new_filter; float* dst = new_filter.mutableData(FP32, f_shape); @@ 
-334,7 +581,9 @@ inline void split_channel(const ConvParam& c_param) { src += param.filter->shape().channel(); } new_filter.flush(); - format_filter(&new_filter, &(conv_param->filter), param.groups); + std::vector scales; + format_filter( + &new_filter, &(conv_param->filter), param.groups, scales, max); Tensor bias; Tensor scale; @@ -356,7 +605,6 @@ inline void split_channel(const ConvParam& c_param) { ConvArgs& args = conv_param->args; args.group_num = param.groups; - args.relu_enabled = param.relu.enabled; args.sb_address = conv_param->scaleBias.data(); args.kernel.stride_h = param.strides[1]; args.kernel.stride_w = param.strides[0]; @@ -374,6 +622,7 @@ inline void split_channel(const ConvParam& c_param) { args.image.height = conv_param->input.shape().height(); args.image.pad_width = param.paddings[1]; args.image.pad_height = param.paddings[0]; + args.dilation = param.dilations[0]; args.output.address = conv_param->output.mutableData(); args.output.scale_address = conv_param->output.scale(); param.splitParams().push_back(conv_param); @@ -384,13 +633,17 @@ inline int fill_split_arg(const ConvParam& c_param) { ConvParam& param = const_cast(c_param); Tensor* input = param.input; Tensor* output = param.output; + if (output->shape().dimSize() == 4 && input->shape().channel() > 2047 && input->shape().width() == 1) { split_channel(c_param); return 1; - } else { + } else if (param.groups == 1) { split_filter_num(c_param); return 0; + } else { + pack_channel_filter(c_param); + return 0; } } @@ -407,7 +660,6 @@ inline bool compute_conv(const ConvParam& c_conv_params) { for (int i = 0; i < 1; i++) { for (int i = 0; i < img.shape().numel(); i++) { float value = half_to_float(img.data()[i]); - std::cout << "value:" << value << std::endl; } } } @@ -416,3 +668,5 @@ inline bool compute_conv(const ConvParam& c_conv_params) { } // namespace zynqmp } // namespace paddle + +#endif /* conv_process_hpp */ diff --git a/lite/backends/fpga/KD/pes/crop_pe.cpp b/lite/backends/fpga/KD/pes/crop_pe.cpp old mode 100644 new mode 100755 index c29df623aa610d329a46ee337cdcb1abd801881c..1438aaba6565cefa72f863d5fc3af0a389fc95e0 --- a/lite/backends/fpga/KD/pes/crop_pe.cpp +++ b/lite/backends/fpga/KD/pes/crop_pe.cpp @@ -14,8 +14,6 @@ limitations under the License. */ #include "lite/backends/fpga/KD/pes/crop_pe.hpp" -#include - namespace paddle { namespace zynqmp { diff --git a/lite/backends/fpga/KD/pes/crop_pe.hpp b/lite/backends/fpga/KD/pes/crop_pe.hpp index 6ebbcdb31f1afb7939c75a2ba9254c0b31f67d31..ccd1e0c98968375ebd840c7e8b15aedd6ad7ef77 100755 --- a/lite/backends/fpga/KD/pes/crop_pe.hpp +++ b/lite/backends/fpga/KD/pes/crop_pe.hpp @@ -14,6 +14,7 @@ limitations under the License. 
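Note: fill_split_arg above now picks between three lowering strategies; compressed into one predicate for readability (a summary sketch mirroring the conditions in the patch):

    // Which path a conv takes in fill_split_arg, given its input and grouping:
    //   - split_channel:       4-D output, >2047 input channels, width 1
    //                          (returns split_axis 1, joined by elementwise add)
    //   - split_filter_num:    ungrouped convs; filters split across passes and
    //                          concatenated back along the channel axis (axis 0)
    //   - pack_channel_filter: grouped convs; groups packed into hardware passes
    const char* split_strategy(int out_dim_size, int in_channels, int in_width,
                               int groups) {
      if (out_dim_size == 4 && in_channels > 2047 && in_width == 1) {
        return "split_channel";
      }
      if (groups == 1) {
        return "split_filter_num";
      }
      return "pack_channel_filter";
    }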
 */

 #pragma once

+#include
 #include
 #include

diff --git a/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp b/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
index 9d7b9b544bff953662bab86f095823c5c7b3075b..9958990af6eb237d2122a63e1b7ed947ca329d31 100755
--- a/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
+++ b/lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp
@@ -24,6 +24,17 @@ namespace zynqmp {

 class DepthwiseConvPE : public PE {
  public:
+  inline int gcd_(int a, int b) {
+    while (b) {
+      int temp = a;
+      a = b;
+      b = temp % b;
+    }
+    return a;
+  }
+
+  inline int lcm_(int a, int b) { return a * b / gcd_(a, b); }
+
   bool init() {
     Tensor* output = param_.output;
     output->setAligned(true);
@@ -37,18 +48,61 @@ class DepthwiseConvPE : public PE {
     Tensor* output = param.output;
     int channel = output->shape().channel();

-    float* new_scale_data = param_.scale()->data<float>();
-    float* new_bias_data = param_.bias()->data<float>();
+    int repeat = 1;
+    int alignment = 16;
+    int length = channel;

-    float16* b_data = bias_.mutableData<float16>(FP16, param_.bias()->shape());
-    for (int i = 0; i < channel; i++) {
-      b_data[i] = float_to_half(new_bias_data[i]);
+    if (channel % alignment != 0 || channel < alignment) {
+      int c_lcm = lcm_(channel, alignment);
+      repeat = c_lcm / (channel);
+    }
+    Shape shape(N, {channel * repeat});
+
+    float16* b_data = bias_.mutableData<float16>(FP16, shape);
+    if (param_.bias()->dataType() == FP32) {
+      float* new_bias_data = param_.bias()->data<float>();
+      // convert the bias from float to float16
+      // for (int i = 0; i < channel; i++) {
+      //   b_data[i] = float_to_half(new_bias_data[i]);
+      // }
+      // pad the bias out to the 16-element alignment by repeating it
+      for (int i = 0; i < repeat; i++) {
+        for (int j = 0; j < length; j++) {
+          float16 value = float_to_half(new_bias_data[j]);
+          b_data[i * length + j] = value;
+        }
+      }
+      bias_.flush();
+    } else {
+      float16* new_bias_data = param_.bias()->data<float16>();
+      // memcpy(b_data, new_bias_data, channel * sizeof(float16));
+      for (int i = 0; i < repeat; i++) {
+        for (int j = 0; j < length; j++) {
+          // float16 value = float_to_half(bias_data_float[j]);
+          b_data[i * length + j] = new_bias_data[j];
+        }
+      }
+      bias_.flush();
     }
-    bias_.flush();

-    Tensor* quantized_filter = param.quantizedFilter();
-    quantized_filter->mutableData<float16>(FP16, param.filter->shape());
-    format_dw_filter(param.filter, param.quantizedFilter(), new_scale_data);
+    if (param_.scale()->dataType() == FP32) {
+      float* new_scale_data = param_.scale()->data<float>();
+      Tensor* quantized_filter = param.quantizedFilter();
+      quantized_filter->mutableData<float16>(FP16, param.filter->shape());
+      format_dw_filter(param.filter, param.quantizedFilter(), new_scale_data);
+
+    } else {
+      // used when the filter is all ones and the channel count is already aligned
+      float16* scale_data = param_.scale()->data<float16>();
+      float16* filter_data = param.quantizedFilter()->mutableData<float16>(
+          FP16, param.filter->shape());
+
+      // memcpy(filter_data, scale_data, channel * sizeof(float16));
+      memcpy(filter_data,
+             scale_data,
+             param.filter->shape().numel() * sizeof(float16));
+      param.quantizedFilter()->flush();
+    }

     DWconvArgs args = {0};
     args.bias_address = b_data;
@@ -71,20 +125,33 @@ class DepthwiseConvPE : public PE {
     args.sub_conv_num = 1;
     param.args = args;

-    inplace_.relu_enable = param_.relu.enabled;
     inplace_.power_enable = false;
     inplace_.normalize_enable = false;
   }

   bool dispatch() {
     param_.input->syncToDevice();
-    if (param_.relu.enabled) {
-      inplace_.relu_enable = param_.relu.enabled;
+    if (param_.activeParam.type == TYPE_RELU) {
+      inplace_.relu_enable = true;
+    } else if (param_.activeParam.type == TYPE_RELU6) {
+      inplace_.relu6_enable = true;
+    } else if
(param_.activeParam.type == TYPE_SIGMOID) { + inplace_.sigmoid_enable = true; + } else if (param_.activeParam.type == TYPE_LEAKY_RELU) { + inplace_.leaky_relu_enable = true; + } + + if (inplace_.relu_enable || inplace_.leaky_relu_enable || + inplace_.relu6_enable || inplace_.sigmoid_enable) { config_inplace(inplace_); } bool ret = compute_fpga_dwconv(param_.args) == 0; - if (param_.relu.enabled) { + if (inplace_.relu_enable || inplace_.leaky_relu_enable || + inplace_.relu6_enable || inplace_.sigmoid_enable) { inplace_.relu_enable = false; + inplace_.leaky_relu_enable = false; + inplace_.relu6_enable = false; + inplace_.sigmoid_enable = false; config_inplace(inplace_); } return ret; diff --git a/lite/backends/fpga/KD/pes/elementwise_add_pe.hpp b/lite/backends/fpga/KD/pes/elementwise_add_pe.hpp index a498a2bde9a3656cf8b7006b867eec088d87b425..6f76ae3d4a1d9d054339d929515f24989f1c15b0 100755 --- a/lite/backends/fpga/KD/pes/elementwise_add_pe.hpp +++ b/lite/backends/fpga/KD/pes/elementwise_add_pe.hpp @@ -58,15 +58,29 @@ class ElementwiseAddPE : public PE { bool dispatch() { param_.inputs[0]->syncToDevice(); param_.inputs[1]->syncToDevice(); - InplaceArgs inplace_args = {0}; - if (param_.relu.enabled) { - inplace_args.relu_enable = true; - config_inplace(inplace_args); + // InplaceArgs inplace_ = {0}; + + if (param_.activeParam.type == TYPE_RELU) { + inplace_.relu_enable = true; + } else if (param_.activeParam.type == TYPE_RELU6) { + inplace_.relu6_enable = true; + } else if (param_.activeParam.type == TYPE_SIGMOID) { + inplace_.sigmoid_enable = true; + } else if (param_.activeParam.type == TYPE_LEAKY_RELU) { + inplace_.leaky_relu_enable = true; + } + if (inplace_.relu_enable || inplace_.leaky_relu_enable || + inplace_.relu6_enable || inplace_.sigmoid_enable) { + config_inplace(inplace_); } compute_fpga_ewadd(param_.ewargs); - if (param_.relu.enabled) { - inplace_args.relu_enable = false; - config_inplace(inplace_args); + if (inplace_.relu_enable || inplace_.leaky_relu_enable || + inplace_.relu6_enable || inplace_.sigmoid_enable) { + inplace_.relu_enable = false; + inplace_.relu6_enable = false; + inplace_.sigmoid_enable = false; + inplace_.leaky_relu_enable = false; + config_inplace(inplace_); } return true; } @@ -75,6 +89,7 @@ class ElementwiseAddPE : public PE { private: ElementwiseAddParam param_; + InplaceArgs inplace_ = {0}; }; } // namespace zynqmp diff --git a/lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp b/lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7730c598b0e8745e47df7d5c456e2b5420fbe6c0 --- /dev/null +++ b/lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp @@ -0,0 +1,77 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
 */
+
+#pragma once
+
+#include "lite/backends/fpga/KD/pe.hpp"
+#include "lite/backends/fpga/KD/pe_params.hpp"
+namespace paddle {
+namespace zynqmp {
+
+class ElementwiseMulPE : public PE {
+ public:
+  bool init() {
+    Tensor* output = param_.output;
+    output->setAligned(true);
+    output->setDataLocation(Device);
+    return true;
+  }
+
+  void apply() {
+    Tensor* input = param_.input_x;
+    Tensor* output = param_.output;
+
+    int wc_aligned = align_to_x(param_.input_x->shape().numel(), 32);
+
+    Shape s(N, {wc_aligned});
+    float16* bias_data = bias_tensor.mutableData<float16>(FP16, s);
+    memset(bias_data, 0, wc_aligned * sizeof(float16));
+
+    ScaleArgs& args = args_;
+    args.scale_address = param_.input_y->data<void>();
+    args.bias_address = bias_tensor.data<void>();
+    args.wc_alignment = wc_aligned;
+    args.channel_alignment = wc_aligned;
+    args.image.address = input->data<void>();
+    args.image.scale_address = input->scale();
+    args.image.channels = wc_aligned;
+    args.image.height = 1;
+    args.image.width = 1;
+    args.image.pad_width = 0;
+    args.image.pad_height = 0;
+    args.output.address = output->data<void>();
+    args.output.scale_address = output->scale();
+  }
+
+  void updateInput(Tensor* t, int index) {
+    if (index == 0) {
+      args_.scale_address = t->data<void>();  // replace inputs?
+    }
+  }
+
+  bool dispatch() {
+    return compute_fpga_scale(args_) == 0;
+  }
+
+  ElementwiseMulParam& param() { return param_; }
+
+ private:
+  ElementwiseMulParam param_;
+  ScaleArgs args_ = {0};
+  Tensor bias_tensor;
+};
+
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/backends/fpga/KD/pes/fully_connected_pe.hpp b/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
index 2179a142ad3b3a990512b3ea1cd202bc5ce502f1..a2b184e383aa600b1279197a115c58309e204a95 100644
--- a/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
+++ b/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
@@ -38,6 +38,8 @@ class FullyConnectedPE : public PE {
     Tensor* input = param_.input;
     convParam_.input = param_.input;
     convParam_.output = param_.output;
+    // convParam_.relu = param_.relu;
+    convParam_.activeParam.type = param_.activeParam.type;
     convParam_.groups = 1;
     convParam_.strides = {1, 1};
     convParam_.paddings = {0, 0};
@@ -46,6 +48,9 @@ class FullyConnectedPE : public PE {

     int num = param_.filter->shape().channel();
     int chw = param_.filter->shape().num();
+    // if (num == 2) {
+    //   return;
+    // }

     int height = param_.input->shape().height();
     int width = param_.input->shape().width();
@@ -82,7 +87,45 @@ class FullyConnectedPE : public PE {
     convPE_.apply();
   }

-  bool dispatch() { return convPE_.dispatch(); }
+  void cpu_compute() {
+    int num = param_.filter->shape().channel();
+    int chw = param_.filter->shape().num();
+
+    float* filter_data = param_.filter->data<float>();
+    float max = 0.0f;
+    Tensor* input = param_.input;
+    Tensor* output = param_.output;
+    float16* input_data = input->data<float16>();
+    float16* output_data = output->data<float16>();
+
+    for (int i = 0; i < num; i++) {
+      float sum = 0;
+      float bias = param_.bias->data<float>()[i];
+      for (int j = 0; j < chw; j++) {
+        float scale = filter_data[j * num + i];
+        float data = half_to_float(input_data[j]);
+        sum += scale * data;
+      }
+      float value = sum + bias;
+      output_data[i] = float_to_half(value);
+      if (max < value) {
+        max = value;
+      }
+    }
+
+    output->flush();
+    output->scale()[0] = max / 127.0f;
+    output->scale()[1] = 127.0f / max;
+  }
+
+  bool dispatch() {
+    // int num = param_.filter->shape().channel();
+    // if (num == 2) {
+    //   cpu_compute();
+    //   return 1;
+    // } else {
+    return convPE_.dispatch();
+    // }
+  }

   FullyConnectedParam& param()
{ return param_; } diff --git a/lite/backends/fpga/KD/pes/gru_pe.hpp b/lite/backends/fpga/KD/pes/gru_pe.hpp new file mode 100755 index 0000000000000000000000000000000000000000..299ffb872b4620fc409eb8e66760a6308a814efb --- /dev/null +++ b/lite/backends/fpga/KD/pes/gru_pe.hpp @@ -0,0 +1,192 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "lite/backends/arm/math/sgemm.h" +#include "lite/backends/fpga/KD/pe.hpp" +#include "lite/backends/fpga/KD/pe_params.hpp" +#include "lite/backends/fpga/KD/pes/elementwise_add_pe.hpp" +#include "lite/backends/fpga/KD/pes/elementwise_mul_pe.hpp" +#include "lite/backends/fpga/KD/pes/fully_connected_pe.hpp" +#include "lite/backends/fpga/KD/pes/relu_pe.hpp" + +#include "lite/api/paddle_place.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace zynqmp { + +struct GRUTensors { + Tensor* gate; + Tensor* pre_output; + Tensor* output; + Tensor* reset_output; +}; + +class GRUPE : public PE { + public: + bool init() { return true; } + + void apply() { + auto hidden = param_.hidden; + int frame_size = hidden->shape().channel(); + + zynqmp::Shape hidden_shape{zynqmp::NCHW, {1, frame_size, 1, 1}}; + float16* prev_hidden_data = + prev_hidden_.mutableData(zynqmp::FP16, hidden_shape); + // set previous hidden data to 0; + memset(prev_hidden_data, 0, hidden_shape.numel() * sizeof(float16)); + + // copy 2/3 weight from param.weight; + zynqmp::Shape weight_shape{zynqmp::NC, {frame_size, frame_size * 2}}; + float* weight_data = weight_.mutableData(zynqmp::FP32, weight_shape); + memset(weight_data, 0, weight_shape.numel() * sizeof(float)); + weight_data = weight_.mutableData(zynqmp::FP32, weight_shape); + memcpy(weight_data, + param_.weight->data(), + weight_shape.numel() * sizeof(float)); + + Shape gate_shape(zynqmp::NC, {1, frame_size * 2}); + gate_ping_.mutableData(FP32, gate_shape); + gate_pong_.mutableData(FP16, gate_shape); + + zynqmp::FullyConnectedParam& pre_out_param = pre_out_pe_.param(); + pre_out_param.input = &prev_hidden_; + pre_out_param.output = &gate_pong_; + pre_out_param.filter = &weight_; + pre_out_param.bias = &gate_ping_; + pre_out_pe_.init(); + pre_out_pe_.apply(); + + reset_gate_.mutableData(FP16, hidden_shape); + prev_hidden_.mutableData(FP16, hidden_shape); + reset_hidden_.mutableData(FP16, hidden_shape); + + ElementwiseMulParam& mul_param = mul_pe_.param(); + // mul_param.inputs = {&reset_gate_, &prev_hidden_}; + mul_param.input_x = &reset_gate_; + mul_param.input_y = &prev_hidden_; + mul_param.output = &reset_hidden_; + mul_pe_.init(); + mul_pe_.apply(); + } + + bool dispatch() { return true; } + + void gru_unit_reset_act(const lite_api::ActivationType active_gate, + GRUTensors& value, // NOLINT + int frame_size, + int batch_size) { + int stride_update = 3 * frame_size; + int stride_cell_state = 3 * frame_size; + int stride_hidden_prev = frame_size; + int stride_hidden = frame_size; + + float* update_gate_data = 
gate_ping_.data(); + float* reset_gate_data = update_gate_data + frame_size; + + for (int b = 0; b < batch_size; b++) { + Tensor tmp; + Shape s(NC, {1, frame_size}); + float* tmp_data = tmp.mutableData(FP32, s); + + for (int i = 0; i < frame_size; i++) { + update_gate_data[i] = + lite::arm::math::active_f32( + update_gate_data[i]); + reset_gate_data[i] = + lite::arm::math::active_f32( + reset_gate_data[i]); + } + memcpy(tmp_data, reset_gate_data, frame_size * sizeof(float)); + tmp.flush(); + reset_gate_.copyFrom(&tmp); + + Tensor* hidden_prev = value.pre_output; + if (hidden_prev) { + // TODO(chonwhite): change to pre_out; + prev_hidden_.copyFrom(value.pre_output); + } + mul_pe_.dispatch(); + update_gate_data += stride_update; + reset_gate_data += stride_update; + + // reset_hidden_prev += stride_hidden;// TODO + } + } + + void gru_unit_out_act(const lite_api::ActivationType active_node, + bool origin_mode, + GRUTensors& value, // NOLINT + int frame_size, + int batch_size) {} + + void copy_input(GRUTensors& value) { // NOLINT + float max = find_max(*(value.gate)); + gate_ping_.mutableData(FP32, value.gate->shape()); + gate_ping_.copyFrom(value.gate); + // update input pointer? + } + + void GRUCOmpute(GRUTensors& value, // NOLINT + int frame_size, + int batch_size, + const lite_api::ActivationType active_node, + const lite_api::ActivationType active_gate, + bool origin_mode) { + copy_input(value); + + if (value.pre_output) { + // copy by batch; + pre_out_pe_.dispatch(); + gate_ping_.copyFrom(&gate_pong_); + } + + gru_unit_reset_act(active_gate, value, frame_size, batch_size); + } + + GRUParam& param() { return param_; } + + Tensor* updateGate() { return &update_gate_; } + + Tensor* resetGate() { return &reset_gate_; } + + private: + GRUParam param_; + zynqmp::Tensor gate_ping_; + zynqmp::Tensor gate_pong_; + zynqmp::Tensor bias_; + zynqmp::Tensor weight_; + zynqmp::Tensor state_weight_; + zynqmp::Tensor update_gate_; + zynqmp::Tensor reset_gate_; + zynqmp::Tensor cell_state_; + zynqmp::Tensor prev_hidden_; + zynqmp::Tensor reset_hidden_; + + Tensor tempTensor; + + ReluPE update_relu_pe_; + ReluPE reset_relu_pe_; + zynqmp::ElementwiseMulPE mul_pe_; + zynqmp::FullyConnectedPE pre_out_pe_; + zynqmp::FullyConnectedPE reset_out_pe_; + + zynqmp::ElementwiseAddPE bias_ew_pe_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h b/lite/backends/fpga/KD/pes/gru_util.hpp similarity index 71% rename from lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h rename to lite/backends/fpga/KD/pes/gru_util.hpp index 3c76e0e8b5cf0842cb8d5a613cef7aee3cd13bdb..d49169846f4f18e4d8e30f3658c2173157678f81 100644 --- a/lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h +++ b/lite/backends/fpga/KD/pes/gru_util.hpp @@ -14,13 +14,10 @@ #pragma once -#include "lite/kernels/xpu/bridges/registry.h" +#include "lite/backends/arm/math/gru_utils.h" -USE_XPU_BRIDGE(relu); -USE_XPU_BRIDGE(conv2d); -USE_XPU_BRIDGE(depthwise_conv2d); -USE_XPU_BRIDGE(elementwise_add); -USE_XPU_BRIDGE(pool2d); -USE_XPU_BRIDGE(softmax); -USE_XPU_BRIDGE(mul); -USE_XPU_BRIDGE(batch_norm); +namespace paddle { +namespace lite { +namespace fpga {} +} +} diff --git a/lite/backends/fpga/KD/pes/norm_pe.hpp b/lite/backends/fpga/KD/pes/norm_pe.hpp index 3e2fd8062766c84282233b91fcaecf5e0a26fd72..0537df27e212014ed309245b0e86b8d8f077489e 100644 --- a/lite/backends/fpga/KD/pes/norm_pe.hpp +++ b/lite/backends/fpga/KD/pes/norm_pe.hpp @@ -72,8 +72,10 @@ class NormPE : public PE { input_float.mutableData(FP32, 
param_.input->shape()); float_out.mutableData(FP32, param_.output->shape()); + // param_.input->syncToDevice(); input_float.copyFrom(param_.input); input_float.syncToCPU(); + // input_float.saveToFile("normalize_", true); int channel = input_float.shape().channel(); int height = input_float.shape().height(); @@ -85,6 +87,7 @@ class NormPE : public PE { float* out_ptr = float_out.data(); int loop = height * width; +#pragma omp parallel for for (int i = 0; i < loop; i++) { float sum = param_.epsilon; for (int c = 0; c < channel; c++) { @@ -98,11 +101,26 @@ class NormPE : public PE { } } float_out.flush(); + // float_out.saveToFile("normalize_", true); param_.output->copyFrom(&float_out); } bool dispatch() { cpuCompute(); + // std::cout << "CPU normalize ---------------------" << std::endl; + + // param_.input->syncToDevice(); + // // param_.input->saveToFile("normalize_fpga_", true); + // config_norm_param(norm_param_args_); + // inplace_args_.normalize_enable = true; + // config_inplace(inplace_args_); + + // perform_bypass(bypass_args_); + // inplace_args_.normalize_enable = false; + // config_inplace(inplace_args_); + // compute_norm(norm_args_); + // param_.output->saveToFile("normalize_fpga_", true); + // std::cout << "FPGA normalize ---------------------" << std::endl; return true; } diff --git a/lite/backends/fpga/KD/pes/output_pe.hpp b/lite/backends/fpga/KD/pes/output_pe.hpp old mode 100644 new mode 100755 index 1c99386ab19f485c07723c7fcc8501bdf5556f6c..2944691693b135a2d2df7b91ecbe0ef249b015d8 --- a/lite/backends/fpga/KD/pes/output_pe.hpp +++ b/lite/backends/fpga/KD/pes/output_pe.hpp @@ -25,6 +25,8 @@ class OutputPE : public PE { bool init() { Tensor* output = param_.output; output->setAligned(false); + DLEngine::get_instance().out_data = reinterpret_cast( + fpga_malloc(output->shape().numel() * sizeof(float))); return true; } @@ -41,6 +43,15 @@ class OutputPE : public PE { } else { output->copyFrom(input); } + // + output->syncToCPU(); + if (DLEngine::get_instance().out_data == nullptr) { + DLEngine::get_instance().out_data = reinterpret_cast( + fpga_malloc(output->shape().numel() * sizeof(float))); + } + memcpy(DLEngine::get_instance().out_data, + output->data(), + output->shape().numel() * sizeof(float)); return true; } diff --git a/lite/backends/fpga/KD/pes/pooling_pe.hpp b/lite/backends/fpga/KD/pes/pooling_pe.hpp index fd3be1f463d3bfce925cc4ce5444d119c33e5692..60755ee1dbf81512bde618389cbf3a88cf93d1ce 100644 --- a/lite/backends/fpga/KD/pes/pooling_pe.hpp +++ b/lite/backends/fpga/KD/pes/pooling_pe.hpp @@ -35,12 +35,17 @@ class PoolingPE : public PE { Tensor* input = param_.input; Tensor* output = param_.output; - uint32_t k_width = param_.kernelSize[0]; - uint32_t k_height = param_.kernelSize[1]; + uint32_t k_height = 1; + uint32_t k_width = 1; if (param_.globalPooling) { k_width = input->shape().width(); k_height = input->shape().height(); + param_.kernelSize[0] = k_height; + param_.kernelSize[1] = k_width; + } else { + k_height = param_.kernelSize[0]; + k_width = param_.kernelSize[1]; } PoolingArgs args = {0}; @@ -63,8 +68,12 @@ class PoolingPE : public PE { args.out_width = output->shape().width(); param_.poolingArgs = args; + // use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 + // && + // (k_width > 7 || k_height > 7); use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && - (k_width > 7 || k_height > 7); + (k_width > 255 || k_height > 255); + // use_cpu_ = param_.type == AVERAGE; } void compute() { @@ -73,6 +82,7 @@ class 
PoolingPE : public PE { input->syncToCPU(); Tensor float_input; + // Tensor float_output; float* image_addr = float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); float16* data_out = output->data(); @@ -107,6 +117,8 @@ class PoolingPE : public PE { for (int c = 0; c < image_channels; ++c) { const int pool_index = (ph * pooled_width_ + pw) * image_channels + c; float sum = 0; + // const int index = + // (hstart * image_width + wstart) * image_channels + c; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { const int index = (h * image_width + w) * image_channels + c; @@ -127,7 +139,7 @@ class PoolingPE : public PE { output->flush(); } - void cpu_compute() { + void cpu_compute1() { Tensor* input = param_.input; Tensor* output = param_.output; input->syncToCPU(); @@ -135,6 +147,7 @@ class PoolingPE : public PE { Tensor float_input; float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); + // float_input.saveToFile("pool_float.txt"); float16* data_out = output->data(); int kernel_hw = param_.kernelSize[0] * param_.kernelSize[1]; @@ -152,13 +165,45 @@ class PoolingPE : public PE { } output->scale()[0] = scale_max / 127.0f; output->scale()[1] = 127.0f / scale_max; - std::cout << "pool scale:" << scale_max / 127.0f << std::endl; output->flush(); + // exit(-1); + } + + void cpu_compute() { + Tensor* input = param_.input; + Tensor* output = param_.output; + input->syncToCPU(); + + Tensor float_input; + float* float_input_data = + float_input.mutableData(FP32, input->shape()); + float_input.copyFrom(input); + + float16* data_out = output->data(); + + int kernel_hw = param_.kernelSize[0] * param_.kernelSize[1]; + + float scale_max = 0; + for (int i = 0; i < output->shape().channel(); i++) { + float sum = 0; + for (int j = 0; j < kernel_hw; j++) { + sum += float_input_data[i * kernel_hw + j]; + } + float value = sum / kernel_hw; + data_out[i] = float_to_half(value); + scale_max = std::max(scale_max, std::abs(value)); + } + output->scale()[0] = scale_max / 127.0f; + output->scale()[1] = 127.0f / scale_max; + output->flush(); + // exit(-1); } bool dispatch() { if (use_cpu_) { + // cpu_compute(); compute(); + // exit(-1); return true; } param_.input->syncToDevice(); diff --git a/lite/backends/fpga/KD/pes/prior_box_pe.cpp b/lite/backends/fpga/KD/pes/prior_box_pe.cpp index d6a503a31d4e0736724740ce1875c916969d93e0..00dfe1830f6f44cbf6a30708fa5783563470c686 100644 --- a/lite/backends/fpga/KD/pes/prior_box_pe.cpp +++ b/lite/backends/fpga/KD/pes/prior_box_pe.cpp @@ -253,9 +253,8 @@ bool PriorBoxPE::dispatch() { if (cachedBoxes_ == nullptr) { cachedBoxes_ = new Tensor(); cachedVariances_ = new Tensor(); - cachedBoxes_->mutableData(FP16, param_.outputBoxes->shape()); - cachedVariances_->mutableData(FP16, - param_.outputVariances->shape()); + cachedBoxes_->mutableData(FP32, param_.outputBoxes->shape()); + cachedVariances_->mutableData(FP32, param_.outputVariances->shape()); cachedBoxes_->setDataLocation(CPU); cachedVariances_->setDataLocation(CPU); compute_prior_box(); diff --git a/lite/backends/fpga/KD/pes/scale_pe.hpp b/lite/backends/fpga/KD/pes/scale_pe.hpp index d5e16615d9943a1771dfabe916433768ecf16319..09755c65a322da8ccab0d57dd2e877712b112361 100755 --- a/lite/backends/fpga/KD/pes/scale_pe.hpp +++ b/lite/backends/fpga/KD/pes/scale_pe.hpp @@ -14,11 +14,16 @@ limitations under the License. 
*/
 #pragma once
 
+#include
+
 #include "lite/backends/fpga/KD/pe.hpp"
 #include "lite/backends/fpga/KD/pe_params.hpp"
+#include "lite/backends/fpga/KD/pes/depthwise_conv_pe.hpp"
+#include "lite/backends/fpga/KD/tensor.hpp"
 
 namespace paddle {
 namespace zynqmp {
+
 class ScalePE : public PE {
  public:
   inline int gcd(int a, int b) {
@@ -42,6 +47,8 @@ class ScalePE : public PE {
     Tensor* input = param_.input;
     Tensor* output = param_.output;
     Shape& input_shape = input->shape();
+    DepthwiseConvParam& dw_param = dw_pe_.param();
+
     int channel = input_shape.channel();
     int repeat = 1;
     int alignment = 16;
@@ -51,70 +58,141 @@ class ScalePE : public PE {
       int c_lcm = lcm(channel, alignment);
       repeat = c_lcm / (channel);
     }
+
+    // FPGA limits: H > 2047, W > 1023, or WC > 65536 require the CPU implementation
     Shape shape(N, {channel * repeat});
-    param_.alignedBias()->mutableData<float16>(FP16, shape);
-    param_.alignedScale()->mutableData<float16>(FP16, shape);
-    float16* bias_data = param_.alignedBias()->data<float16>();
-    float16* scale_data = param_.alignedScale()->data<float16>();
+    float* filter_data = filter.mutableData<float>(FP32, shape);
+    std::fill_n(filter_data, input->shape().channel(), 1.0f);
 
-    if (param_.bias != nullptr) {
-      float* bias_data_float = param_.bias->data<float>();
+    Tensor* scale = dw_param.scale();
+    float16* scale_data = scale->mutableData<float16>(FP16, shape);
+
+    Tensor* bias = dw_param.bias();
+    float16* bias_data = bias->mutableData<float16>(FP16, shape);
+    std::fill_n(bias_data, input->shape().channel(), 0);
+
+    if (param_.scale->dataType() == FP32) {
+      if (param_.bias != nullptr) {
+        float* bias_data_float = param_.bias->data<float>();
+        for (int i = 0; i < repeat; i++) {
+          for (int j = 0; j < length; j++) {
+            float16 value = float_to_half(bias_data_float[j]);
+            bias_data[i * length + j] = value;
+          }
+        }
+      } else {
+        float16 zero = float_to_half(0.0f);
+        for (int i = 0; i < repeat; i++) {
+          for (int j = 0; j < length; j++) {
+            bias_data[i * length + j] = zero;
+          }
+        }
+      }
+      float* scale_data_float = param_.scale->data<float>();
       for (int i = 0; i < repeat; i++) {
         for (int j = 0; j < length; j++) {
-          float16 value = float_to_half(bias_data_float[j]);
-          bias_data[i * length + j] = value;
+          float16 value = float_to_half(scale_data_float[j]);
+          scale_data[i * length + j] = value;
         }
       }
     } else {
-      float16 zero = float_to_half(0.0f);
+      if (param_.bias != nullptr) {
+        float16* bias_data_float = param_.bias->data<float16>();
+        for (int i = 0; i < repeat; i++) {
+          for (int j = 0; j < length; j++) {
+            float16 value = bias_data_float[j];
+            bias_data[i * length + j] = value;
+          }
+        }
+      } else {
+        float16 zero = float_to_half(0.0f);
+        for (int i = 0; i < repeat; i++) {
+          for (int j = 0; j < length; j++) {
+            bias_data[i * length + j] = zero;
+          }
+        }
+      }
+
+      float16* scale_data_float = param_.scale->data<float16>();
       for (int i = 0; i < repeat; i++) {
         for (int j = 0; j < length; j++) {
-          bias_data[i * length + j] = zero;
+          float16 value = scale_data_float[j];
+          scale_data[i * length + j] = value;
         }
       }
     }
 
-    float* scale_data_float = param_.scale->data<float>();
-    for (int i = 0; i < repeat; i++) {
-      for (int j = 0; j < length; j++) {
-        float16 value = float_to_half(scale_data_float[j]);
-        scale_data[i * length + j] = value;
+    dw_param.input = param_.input;
+    dw_param.output = param_.output;
+    dw_param.filter = &filter;
+
+    dw_param.strides = {1, 1};
+    dw_param.paddings = {0, 0};
+    dw_param.kernelSize = {1, 1};
+    dw_param.dilations = {1, 1};
+
+    dw_pe_.init();
+    dw_pe_.apply();
+  }
+
+  void cpu_compute() {
+    Tensor* input = param_.input;
+    Tensor* output = param_.output;
+    Tensor float_input;
+    float* image_addr = float_input.mutableData<float>(FP32, input->shape());
+    input->syncToCPU();
+    float_input.copyFrom(input);
+    float16* data_out = output->data<float16>();
+
+    float* scale_data = param_.scale->data<float>();
+
+    int wh = input->shape().width() * input->shape().height();
+
+    float16* in_data = input->data<float16>();
+
+    float max = 0;
+
+    for (int i = 0; i < wh; i++) {
+      for (int c = 0; c < input->shape().channel(); c++) {
+        int index = i * input->shape().channel() + c;
+        float value = half_to_float(in_data[index]) * scale_data[c];
+        data_out[index] = float_to_half(value);
+
+        if (value < 0) {
+          value = -value;
+        }
+        if (value > max) {
+          max = value;
+        }
       }
     }
-
-    param_.alignedScale()->flush();
-    param_.alignedBias()->flush();
-
-    int wc = input_shape.width() * input_shape.channel();
-    int wc_aligned = align_image(wc);
-
-    ScaleArgs& args = param_.args;
-    args.scale_address = param_.alignedScale()->data();
-    args.bias_address = param_.alignedBias()->data();
-    args.wc_alignment = wc_aligned;
-    args.channel_alignment = channel * repeat;
-
-    args.image.address = input->data();
-    args.image.scale_address = input->scale();
-    args.image.channels = channel;
-    args.image.height = input_shape.height();
-    args.image.width = input_shape.width();
-    args.image.pad_width = 0;
-    args.image.pad_height = 0;
-    args.output.address = output->data();
-    args.output.scale_address = output->scale();
+    output->flush();
+    output->scale()[0] = max / 127.0f;
+    output->scale()[1] = 127.0f / max;
   }
 
   bool dispatch() {
+    if (param_.scale->dataType() == FP16) {
+      DepthwiseConvParam& dw_param = dw_pe_.param();
+      memcpy(dw_param.quantizedFilter()->mutableData<float16>(),
+             param_.scale->data<float16>(),
+             param_.scale->shape().numel() * sizeof(float16));
+      dw_param.quantizedFilter()->scale()[0] = param_.scale->scale()[0];
+      dw_param.quantizedFilter()->scale()[1] = param_.scale->scale()[1];
+
+      dw_param.quantizedFilter()->flush();
+    }
     param_.input->syncToDevice();
-    return compute_fpga_scale(param_.args) == 0;
+    return dw_pe_.dispatch();
  }
 
   ScaleParam& param() { return param_; }
 
  private:
   ScaleParam param_;
+  Tensor filter;
+  DepthwiseConvPE dw_pe_;
 };
 } // namespace zynqmp
 } // namespace paddle
diff --git a/lite/backends/fpga/KD/pes/split_pe.hpp b/lite/backends/fpga/KD/pes/split_pe.hpp
index 26598a4c87f0b88882b3fe76de64ddfa5c6cd6a8..01a036787441c596bf74858aa9bf6a6613864cc1 100644
--- a/lite/backends/fpga/KD/pes/split_pe.hpp
+++ b/lite/backends/fpga/KD/pes/split_pe.hpp
@@ -53,20 +53,37 @@ class SplitPE : public PE {
     int64_t src_after = src_stride_numel[axis];
     int64_t dst_after = dst_stride_numel[axis];
 
+    // PADDLE_MOBILE_ENFORCE(src_stride_numel.size() == dst_stride_numel.size(),
+    //                       "src and dst tensor should have the same dims
+    //                       size.");
+
     for (int64_t i = 0; i < axis; ++i) {
       if (i < axis) {
+        // PADDLE_MOBILE_ENFORCE(src_stride_numel[i] / src_stride_numel[axis] ==
+        //                       dst_stride_numel[i] /
+        //                       dst_stride_numel[axis],
+        //                       "src and dst should have the same elements "
+        //                       "except the specified axis.");
       } else if (i == axis) {
         continue;
       } else {
+        // PADDLE_MOBILE_ENFORCE(src_stride_numel[i] == dst_stride_numel[i],
+        //                       "src and dst should have the same elements "
+        //                       "except the specified axis.");
      }
    }
 
     for (int64_t i = 0; i < before; ++i) {
-      memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
+      memcpy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
     }
   }
 
-  void split3D() { int axis = param_.axis; }
+  void split3D() {
+    int axis = param_.axis;
+    // float16* dst = param_.output->data<float16>();
+    // std::vector& dst_dims = ;
+    //
StridedNumelCopyWithAxis(); + } bool dispatch() { Tensor* input = param_.input; @@ -88,6 +105,7 @@ class SplitPE : public PE { in_stride, out_stride[axis]); input_offset += out_stride[axis]; + // out->flush(); } return true; } @@ -95,21 +113,26 @@ class SplitPE : public PE { std::vector outputs = param_.outputs; int in_channel = input->shape().channel(); - int split_channel = input->shape().channel() / param_.num; + // int split_channel = input->shape().channel() / param_.num; int hw = input->shape().height() * input->shape().width(); float16* in_data = input->data(); + for (int i = 0; i < hw; i++) { + int channel_stride = 0; for (int n = 0; n < outputs.size(); n++) { Tensor* out = outputs[n]; float16* out_data = out->data(); - memcpy(out_data + i * split_channel, - in_data + i * in_channel + n * split_channel, - split_channel * sizeof(float16)); + memcpy(out_data + i * out->shape().channel(), + in_data + i * in_channel + channel_stride, + out->shape().channel() * sizeof(float16)); + channel_stride += out->shape().channel(); } } + for (int n = 0; n < outputs.size(); n++) { Tensor* out = outputs[n]; + out->flush(); out->copyScaleFrom(input); } return true; @@ -120,5 +143,6 @@ class SplitPE : public PE { private: SplitParam param_; }; + } // namespace zynqmp } // namespace paddle diff --git a/lite/backends/fpga/KD/shape.hpp b/lite/backends/fpga/KD/shape.hpp index 566ad8e6ff2eff32301e83b6cdb5b1addd0117fe..c25c3315145137a147928a164fcabd2923b09e87 100755 --- a/lite/backends/fpga/KD/shape.hpp +++ b/lite/backends/fpga/KD/shape.hpp @@ -23,6 +23,7 @@ limitations under the License. */ namespace paddle { namespace zynqmp { +static struct None none_; static struct NCHW nchw_; static struct NHWC nhwc_; static struct NC nc_; @@ -82,6 +83,9 @@ class Shape { void setLayoutType(LayoutType layout) { this->layoutType_ = layout; switch (layout) { + case None: + layout_ = &none_; + break; case NCHW: layout_ = &nchw_; break; diff --git a/lite/backends/fpga/KD/tensor.hpp b/lite/backends/fpga/KD/tensor.hpp index f003ded33eb51136ae0ae0a2c21988460232f89a..988bc1bb507036de8f13a6c6549c549718bd1256 100644 --- a/lite/backends/fpga/KD/tensor.hpp +++ b/lite/backends/fpga/KD/tensor.hpp @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include @@ -24,13 +25,10 @@ limitations under the License. 
*/ #include #include -// #include "lite/core/tensor.h" - #include "lite/backends/fpga/KD/dl_engine.hpp" #include "lite/backends/fpga/KD/float16.hpp" #include "lite/backends/fpga/KD/llapi/zynqmp_api.h" #include "lite/backends/fpga/KD/shape.hpp" -// #include "lite/backends/fpga/KD/types.hpp" namespace paddle { namespace zynqmp { @@ -117,7 +115,8 @@ class Tensor { template Dtype* mutableData() { - size_t memorySize = shape_->memorySize(CellSize(dataType_)); + size_t memorySize = + shape_->memorySize(CellSize(dataType_)) * mem_scale_factor_; if (placeHolder_ != nullptr) { if (memorySize > placeHolder_->memorySize()) { placeHolder_.reset(new PlaceHolder(memorySize)); @@ -241,6 +240,10 @@ class Tensor { } } + void setMemScale(float scale_factor) { + this->mem_scale_factor_ = scale_factor; + } + void shareDataWith(Tensor* src) { shareDataWith(src, src->shape()); } void shareDataWith(Tensor* src, const Shape& shape, int offset = 0) { @@ -276,9 +279,11 @@ class Tensor { .height = 1, .pad_width = 0u, .pad_height = 0u}; - args.output = { + + ImageOutputArgs output = { .address = data(), .scale_address = scale(), }; + args.output = output; src->syncToDevice(); size_t aligned_remainder = src->shape().numel() % 16; if (aligned_remainder > 0) { @@ -294,10 +299,14 @@ class Tensor { this->invalidate(); } - void flush() { fpga_flush(placeHolder_->data(), placeHolder_->memorySize()); } + void flush() { + size_t memorySize = placeHolder_->memorySize(); + fpga_flush(placeHolder_->data(), memorySize); + } void invalidate() { - fpga_invalidate(placeHolder_->data(), placeHolder_->memorySize()); + size_t memorySize = placeHolder_->memorySize(); + fpga_invalidate(placeHolder_->data(), memorySize); } void sync() { @@ -339,6 +348,8 @@ class Tensor { } } + void printScale(std::string type) { printScale(); } + std::string dimsFileName() { return std::to_string(shape_->num()) + "_" + std::to_string(shape_->channel()) + "_" + @@ -358,29 +369,9 @@ class Tensor { saveToFile(path); } - friend std::ostream& operator<<(std::ostream& os, Tensor& tensor) { - os << "tensor:" - << "\n"; - os << "dims: {"; - for (int i = 0; i < tensor.shape().dimSize(); ++i) { - os << tensor.shape()[i] << " "; - } - os << "}\n"; - for (int i = 0; i < tensor.shape().numel(); i++) { - float value = 0; - if (tensor.dataType() == FP32) { - value = tensor.data()[i]; - } else { - value = half_to_float(tensor.data()[i]); - } - os << value << " "; - } - os << "\n"; - return os; - } - void saveToFile(std::string path) { syncToCPU(); + invalidate(); std::ofstream ofs; static int counter = 0; std::string npath = std::to_string(counter) + "_" + path; @@ -389,17 +380,19 @@ class Tensor { } void save_file_with_name(std::string path) { - // return; invalidate(); std::ofstream ofs; - ofs.open(path); + ofs << scale()[0] << " / " << scale()[1] << std::endl; + for (int i = 0; i < shape_->numel(); i++) { float value = 0; if (dataType_ == FP32) { value = data()[i]; - } else { + } else if (dataType_ == FP16) { value = half_to_float(data()[i]); + } else { + value = data()[i]; } ofs << value << std::endl; } @@ -415,18 +408,49 @@ class Tensor { int num = shape_->numel(); invalidate(); float max = 0.0f; - float16* data = mutableData(); - for (int i = 0; i < num; ++i) { - float value = 0; - file_stream >> value; - max = std::max(std::abs(value), max); - data[i] = float_to_half(value); + if (dataType_ == FP16) { + float16* data = mutableData(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + max = std::max(std::abs(value), max); + data[i] = 
float_to_half(value); + } + } else { + float* data = mutableData(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + max = std::max(std::abs(value), max); + data[i] = value; + } } flush(); placeHolder_->scale_[0] = max / 127.0f; placeHolder_->scale_[1] = 127.0f / max; } + friend std::ostream& operator<<(std::ostream& os, Tensor& tensor) { + os << "tensor:" + << "\n"; + os << "dims: {"; + for (int i = 0; i < tensor.shape().dimSize(); ++i) { + os << tensor.shape()[i] << " "; + } + os << "}\n"; + for (int i = 0; i < tensor.shape().numel(); i++) { + float value = 0; + if (tensor.dataType() == FP32) { + value = tensor.data()[i]; + } else { + value = half_to_float(tensor.data()[i]); + } + os << value << " "; + } + os << "\n"; + return os; + } + ~Tensor() { if (shape_ != nullptr) { delete shape_; @@ -436,6 +460,7 @@ class Tensor { private: int offset = 0; + float mem_scale_factor_ = 1.0f; std::shared_ptr placeHolder_; Shape* shape_ = nullptr; DataType dataType_ = FP32; diff --git a/lite/backends/fpga/lite_tensor.cc b/lite/backends/fpga/lite_tensor.cc old mode 100644 new mode 100755 index 43218173fd05626fb46495bb254b250c14e5417a..7f1e8d3e17f97315e77532b77bbcfcc8331edd4f --- a/lite/backends/fpga/lite_tensor.cc +++ b/lite/backends/fpga/lite_tensor.cc @@ -95,16 +95,14 @@ void TensorLite::CopyDataFrom(const TensorLite &other) { dims_ = other.dims_; target_ = other.target_; lod_ = other.lod_; - // memory_size_ = other.memory_size_; - // buffer_->CopyDataFrom(*other.buffer_, memory_size_); - zynq_tensor_->mutableData(other.zynq_tensor_->dataType(), - other.zynq_tensor_->shape()); -} + auto dt = zynq_tensor_->dataType(); -// template -// void TensorLite::mutable_data_internal() { + auto shape = other.zynq_tensor_->shape(); -// } + Resize(other.dims()); + zynq_tensor_->mutableData(zynq_tensor_->dataType(), shape); + this->ZynqTensor()->copyFrom(other.ZynqTensor()); +} } // namespace lite } // namespace paddle diff --git a/lite/backends/fpga/lite_tensor.h b/lite/backends/fpga/lite_tensor.h index 2f9df3abb08dd15641323f4a3c59d6175f2e481b..266e0b5ce0ea03108978c3b0a32fbf0e3872c83c 100644 --- a/lite/backends/fpga/lite_tensor.h +++ b/lite/backends/fpga/lite_tensor.h @@ -106,7 +106,7 @@ class TensorLite { // For other devices, T and R may be the same type. template const R *data() const { - return zynq_tensor_->data(); + return zynq_tensor_->data() + offset_; } void Resize(const DDimLite &ddim) { dims_ = ddim; } @@ -125,6 +125,7 @@ class TensorLite { bool persistable() const { return persistable_; } void set_persistable(bool persistable) { persistable_ = persistable; } + // T is the data type and R is the return type // For OpenCL, the return type can be cl::Buffer // and the data type can be float/int8_t. @@ -147,7 +148,13 @@ class TensorLite { size_t memory_size() const { return zynq_tensor_->memorySize(); } + size_t offset() const { return offset_; } + bool IsInitialized() const { return buffer_->data(); } + void clear() { + buffer_->Free(); + offset_ = 0; + } // Other share data to this. void ShareDataWith(const TensorLite &other); @@ -157,6 +164,9 @@ class TensorLite { template TensorLite Slice(int64_t begin, int64_t end) const; + template + void Slice(TensorLite &dst, int64_t begin, int64_t end) const; // NOLINT + TargetType target() const { return target_; } zynqmp::Tensor *ZynqTensor() const { return zynq_tensor_; } @@ -173,16 +183,21 @@ class TensorLite { private: TargetType target_{TargetType::kHost}; + + // precision_ and persistable_ are only used for persistable vars. 
+ // If your tensor wants to be saved and loaded correctly, you must + // set values of precision_ and persistable_ after updating it. + // If your tensor is just a temp tensor, such as activations, + // you can ignore these two attributes. + PrecisionType precision_{PrecisionType::kUnk}; + bool persistable_{false}; + DDimLite dims_; std::shared_ptr buffer_; LoD lod_; size_t memory_size_{}; - size_t offset_{0}; - PrecisionType precision_{PrecisionType::kUnk}; - bool persistable_{false}; - zynqmp::Tensor *zynq_tensor_ = new zynqmp::Tensor(); template @@ -197,6 +212,9 @@ R *TensorLite::mutable_data() { } zynqmp::LayoutType layout_type = zynqmp::NCHW; switch (v.size()) { + case 0: + layout_type = zynqmp::None; + break; case 1: layout_type = zynqmp::N; break; @@ -228,24 +246,60 @@ R *TensorLite::mutable_data(TargetType target) { return mutable_data(); } -template -bool TensorCompareWith(const TensorT &a, const TensorT &b) { - if (a.dims() != b.dims()) return false; - if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; - return true; -} template TensorLite TensorLite::Slice(int64_t begin, int64_t end) const { - int64_t base = numel() / dims_[0]; + throw - 1; + CHECK_GE(begin, 0); + CHECK_LE(end, dims_[0]); + CHECK_LT(begin, end); + if (dims_[0] == 1) { + return *this; + } else { + int64_t base = numel() / dims_[0]; + + TensorLite dst; + dst.target_ = target_; + auto dst_dims = dims_; + dst_dims[0] = end - begin; + dst.Resize(dst_dims); + void *dst_data = dst.mutable_data(); + + T *src_data = const_cast(data()); + memcpy(dst_data, + src_data + static_cast(begin * base) * sizeof(T), + dst_dims.production() * sizeof(T)); + dst.ZynqTensor()->saveToFile("_slice", true); + + return dst; + } +} + +template +void TensorLite::Slice(TensorLite &dst, int64_t begin, int64_t end) const { + CHECK_GE(begin, 0); + CHECK_LE(end, dims_[0]); + CHECK_LT(begin, end); - TensorLite dst; - dst.buffer_ = buffer_; dst.target_ = target_; auto dst_dims = dims_; dst_dims[0] = end - begin; dst.Resize(dst_dims); - dst.offset_ = offset_ + static_cast(begin * base) * sizeof(T); - return dst; + void *dst_data = dst.mutable_data(); + + int64_t base = numel() / dims_[0]; + + T *src_data = const_cast(data()); + memcpy(dst_data, + src_data + static_cast(begin * dst_dims.production()), + dst_dims.production() * sizeof(T)); } + +template +bool TensorCompareWith(const TensorT &a, const TensorT &b) { + if (a.dims() != b.dims()) return false; + if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; + return true; +} + } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/CMakeLists.txt b/lite/backends/npu/CMakeLists.txt index 426ff5698146c773c818b2bfd598d6bbbdf7867f..1540741d331097961dcf7cd791c9785a9c53ddd1 100644 --- a/lite/backends/npu/CMakeLists.txt +++ b/lite/backends/npu/CMakeLists.txt @@ -2,5 +2,4 @@ if(NOT LITE_WITH_NPU) return() endif() -lite_cc_library(npu_runtime SRCS runtime.cc DEPS ${npu_runtime_libs}) -lite_cc_library(npu_builder SRCS builder.cc DEPS ${npu_builder_libs} npu_runtime tensor op scope) +lite_cc_library(device_npu SRCS device.cc DEPS ${npu_builder_libs} ${npu_runtime_libs}) diff --git a/lite/backends/npu/builder.h b/lite/backends/npu/builder.h deleted file mode 100644 index 70200354fbab15f043a537300e92e2a26a3d739e..0000000000000000000000000000000000000000 --- a/lite/backends/npu/builder.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include "ai_ddk_lib/include/graph/buffer.h" -#include "ai_ddk_lib/include/graph/graph.h" -#include "ai_ddk_lib/include/graph/model.h" -#include "ai_ddk_lib/include/graph/op/all_ops.h" -#include "ai_ddk_lib/include/graph/operator.h" -#include "ai_ddk_lib/include/graph/operator_reg.h" -#include "ai_ddk_lib/include/hiai_ir_build.h" -#include "lite/core/op_lite.h" -#include "lite/core/target_wrapper.h" -#include "lite/core/tensor.h" - -// Extended Ops of HIAI DDK -namespace ge { -/** - * Pads a tensor. - * - * x : the input tensor - * padding : the input tensor must be 2-D - * constant_values : constant values must be a scalar - * - * output : the output tensor - * - * t_paddings : Default DT_INT32 , t_paddings must be the same with - * datatype of the padding - * mode : 0: CONSTANT, 1: REFLECT, 2: SYMMETRIC - * T : datatype of constant_values DT_INT32:3 DT_FLOAT:0 - */ -REG_OP(Pad) - .INPUT(x, TensorType({DT_FLOAT, DT_INT32})) - .INPUT(padding, TensorType({DT_INT32})) - .OPTIONAL_INPUT(constant_values, TensorType({DT_INT32, DT_FLOAT})) - .OUTPUT(output, TensorType({DT_FLOAT, DT_INT32})) - .ATTR(t_paddings, AttrValue::INT{3}) - .ATTR(mode, AttrValue::INT{0}) - .REQUIRED_ATTR(T, AttrValue::INT) - .OP_END(); - -} // namespace ge - -namespace paddle { -namespace lite { -namespace npu { - -class OpList { - public: - static OpList& Global() { - static thread_local OpList x; - return x; - } - void clear() { lists_.clear(); } - void add(std::shared_ptr p) { lists_.push_back(p); } - - private: - std::vector> lists_; -}; - -// Build HIAI IR graph to om model, and store om model data into lite tensor -bool BuildModel(std::vector& inputs, // NOLINT - std::vector& outputs, // NOLINT - lite::Tensor* model_data); - -std::string UniqueName(const std::string& prefix); - -ge::DataType CvtPrecisionType(PrecisionType itype); - -ge::Format CvtDataLayoutType(DataLayoutType itype); - -ge::TensorPtr CvtTensor(Tensor* in_tensor, - std::vector out_shape = {}, - PrecisionType in_ptype = PRECISION(kFloat), - DataLayoutType in_ltype = DATALAYOUT(kNCHW)); - -template -ge::TensorPtr CreateTensorAndFillData(std::vector data, - std::vector shape = {}, - ge::Format format = ge::FORMAT_NCHW) { - const std::type_info& info = typeid(T); - ge::DataType type = ge::DT_FLOAT; - if (info == typeid(float)) { - type = ge::DT_FLOAT; - } else if (info == typeid(int8_t)) { - type = ge::DT_INT8; - } else if (info == typeid(int32_t)) { - type = ge::DT_INT32; - } else { - LOG(FATAL) << "[NPU] Unknow value type " << info.name(); - } - if (shape.empty()) { - shape = {static_cast(data.size())}; - } else { - int size = 1; - for (auto i : shape) { - size *= i; - } - CHECK_EQ(data.size(), size); - } - ge::TensorDesc desc(ge::Shape(shape), format, type); - ge::TensorPtr tensor = std::make_shared(); - tensor->SetTensorDesc(desc); - tensor->SetData(reinterpret_cast(data.data()), - data.size() * sizeof(T)); - return tensor; -} - 
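Note: the helper deleted directly above built a constant HiAI tensor from host data, mapping the element type T to ge::DT_FLOAT / DT_INT8 / DT_INT32 and checking that data.size() matches the shape volume. A minimal usage sketch of the old API (illustrative only; the per-channel values and the 1x3x1x1 shape are hypothetical):

    std::vector<float> scale_values = {0.5f, 0.25f, 0.125f};  // one value per channel
    std::vector<int64_t> scale_shape = {1, 3, 1, 1};          // NCHW layout
    ge::TensorPtr scale_tensor =
        CreateTensorAndFillData(scale_values, scale_shape, ge::FORMAT_NCHW);

The scalar overload deleted next simply broadcasts a single value over the requested shape and forwards to this vector version.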
-template -ge::TensorPtr CreateTensorAndFillData(T value, - std::vector shape = {1}, - ge::Format format = ge::FORMAT_NCHW) { - int64_t size = 1; - for (auto i : shape) { - size *= i; - } - std::vector data(size, value); - return CreateTensorAndFillData(data, shape, format); -} - -int CvtActMode(std::string act_type); - -bool HasInputArg(const OpInfo* op_info, - const Scope* scope, - const std::string& argname); - -} // namespace npu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/npu/device.cc b/lite/backends/npu/device.cc new file mode 100644 index 0000000000000000000000000000000000000000..d62ac9cad3e5ab4e6f63e3b667e3fa93e244fec1 --- /dev/null +++ b/lite/backends/npu/device.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/npu/device.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace npu { + +std::unique_ptr Device::Build( + std::string& model_name, // NOLINT + std::vector& input_nodes, // NOLINT + std::vector& output_nodes // NOLINT + ) { + VLOG(3) << "[NPU] Build model"; + // Build the HiAI IR graph to the HiAI om model + ge::Graph ir_graph("graph"); + ir_graph.SetInputs(input_nodes).SetOutputs(output_nodes); + ge::Model om_model("model", "model"); + om_model.SetGraph(ir_graph); + domi::HiaiIrBuild ir_build; + domi::ModelBufferData om_model_buf; + if (!ir_build.CreateModelBuff(om_model, om_model_buf)) { + LOG(WARNING) << "[NPU] CreateModelBuff failed!"; + return nullptr; + } + if (!ir_build.BuildIRModel(om_model, om_model_buf)) { + LOG(WARNING) << "[NPU] BuildIRModel failed!"; + ir_build.ReleaseModelBuff(om_model_buf); + return nullptr; + } + // Create a HiAI model manager client to load the HiAI om model + std::unique_ptr model_client( + new hiai::AiModelMngerClient()); + if (model_client->Init(nullptr) != hiai::AI_SUCCESS) { + LOG(WARNING) << "[NPU] AiModelMngerClient init failed)!"; + ir_build.ReleaseModelBuff(om_model_buf); + return nullptr; + } + model_name = "model_" + std::to_string(model_count_++) + ".om"; + auto model_desc = std::make_shared( + model_name, freq_level(), framework_type(), model_type(), device_type()); + model_desc->SetModelBuffer(om_model_buf.data, om_model_buf.length); + std::vector> model_descs; + model_descs.push_back(model_desc); + if (model_client->Load(model_descs) != hiai::AI_SUCCESS) { + LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!"; + ir_build.ReleaseModelBuff(om_model_buf); + return nullptr; + } + ir_build.ReleaseModelBuff(om_model_buf); + VLOG(3) << "[NPU] Build done"; + return model_client; +} + +} // namespace npu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/npu/runtime.h b/lite/backends/npu/device.h similarity index 63% rename from lite/backends/npu/runtime.h rename to lite/backends/npu/device.h index 8b1ad51518d8626d9a6ecd6203a70b2637bb6004..411600ae0a38e4ee1b4a3ce3d6519b927eeb0a1a 100644 --- a/lite/backends/npu/runtime.h +++ 
b/lite/backends/npu/device.h @@ -13,38 +13,47 @@ // limitations under the License. #pragma once + #include #include -#include "ai_ddk_lib/include/HiAiModelManagerService.h" -#include "lite/core/tensor.h" +#include +#include +#include "HiAiModelManagerService.h" // NOLINT +#include "hiai_ir_build.h" // NOLINT namespace paddle { namespace lite { namespace npu { -class DeviceInfo { +class Device { public: - static DeviceInfo &Global() { - static DeviceInfo x; + static Device& Global() { + static Device x; return x; } - DeviceInfo() {} + Device() {} int freq_level() { return freq_level_; } int framework_type() { return framework_type_; } int model_type() { return model_type_; } int device_type() { return device_type_; } + // Build the HiAI IR graph to om model, return HiAI model manager client to + // load om model and run inference. + std::unique_ptr Build( + std::string& model_name, // NOLINT + std::vector& input_nodes, // NOLINT + std::vector& output_nodes // NOLINT + ); // NOLINT + private: int freq_level_{3}; int framework_type_{0}; int model_type_{0}; int device_type_{0}; + int model_count_{0}; }; -bool LoadModel(const lite::Tensor &model_data, - std::shared_ptr *model_client, - std::string *model_name); } // namespace npu } // namespace lite } // namespace paddle diff --git a/lite/backends/npu/runtime.cc b/lite/backends/npu/runtime.cc deleted file mode 100644 index 3485f63c7c8bb91081fd1969d0d41733417149d9..0000000000000000000000000000000000000000 --- a/lite/backends/npu/runtime.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
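The runtime.cc removed in the remainder of this hunk provided LoadModel, which could only load an om model that had been serialized into a lite::Tensor; it is superseded by npu::Device::Build in device.cc above, which compiles the HiAI IR graph and loads the resulting om model in a single call. A rough caller-side sketch, assuming the subgraph bridges have already produced the input/output ge::Operator nodes (variable names here are illustrative):

    std::string model_name;                  // set by Build(), e.g. "model_0.om"
    std::vector<ge::Operator> input_nodes;   // graph inputs from the bridges
    std::vector<ge::Operator> output_nodes;  // graph outputs from the bridges
    auto model_client =
        lite::npu::Device::Global().Build(model_name, input_nodes, output_nodes);
    if (model_client == nullptr) {
      LOG(WARNING) << "[NPU] building/loading the om model failed";
    }

On success, the returned hiai::AiModelMngerClient already holds the model registered under model_name, so callers no longer manage the om buffer themselves.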
- -#include "lite/backends/npu/runtime.h" -#include -#include -#include "lite/utils/cp_logging.h" - -namespace paddle { -namespace lite { -namespace npu { - -// Create hiai model manager to load om model from lite tensor, and return the -// manager and an unique model name -bool LoadModel(const lite::Tensor &model_data, - std::shared_ptr *model_client, - std::string *model_name) { - LOG(INFO) << "[NPU] Load model."; - auto model_data_ptr = model_data.data(); - auto model_data_size = model_data.numel() * sizeof(int8_t); - if (model_data_ptr == nullptr || model_data_size == 0) { - return false; - } - *model_client = std::make_shared(); - int ret = (*model_client)->Init(nullptr); - if (ret != hiai::AI_SUCCESS) { - LOG(WARNING) << "[NPU] AiModelMngerClient init failed(" << ret << ")!"; - return false; - } - *model_name = "model.om"; - auto model_desc = std::make_shared( - *model_name, - DeviceInfo::Global().freq_level(), - DeviceInfo::Global().framework_type(), - DeviceInfo::Global().model_type(), - DeviceInfo::Global().device_type()); - model_desc->SetModelBuffer(model_data_ptr, model_data_size); - std::vector> model_descs; - model_descs.push_back(model_desc); - if ((*model_client)->Load(model_descs) != hiai::AI_SUCCESS) { - LOG(WARNING) << "[NPU] AiModelMngerClient load model failed!"; - return false; - } - return true; -} - -} // namespace npu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/opencl/CMakeLists.txt b/lite/backends/opencl/CMakeLists.txt index 1acb98321844191832fd55b640a9b56d3d51b400..dd7f6b417e0d6416eec9bb3e60ef088432776112 100644 --- a/lite/backends/opencl/CMakeLists.txt +++ b/lite/backends/opencl/CMakeLists.txt @@ -11,8 +11,8 @@ lite_cc_library(cl_image SRCS cl_image.cc DEPS tensor cl_image_converter cl_runt lite_cc_library(cl_caller SRCS cl_caller.cc DEPS cl_context cl_image) lite_cc_library(cl_target_wrapper SRCS target_wrapper.cc DEPS cl_runtime) lite_cc_test(test_cl_functions SRCS cl_functions_test.cc DEPS cl_context cl_image cl_caller cl_wrapper cl_target_wrapper - ARGS --cl_path=${CMAKE_SOURCE_DIR}/paddle/fluid/lite/backends/opencl) + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_cl_im2col SRCS cl_im2col_test.cc DEPS tensor cl_context cl_wrapper cl_target_wrapper - ARGS --cl_path=${CMAKE_SOURCE_DIR}/paddle/fluid/lite/backends/opencl) + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) add_dependencies(cl_wrapper opencl_clhpp) diff --git a/lite/backends/opencl/cl_caller.cc b/lite/backends/opencl/cl_caller.cc index 4926a53c43d54b4e2b4d802a7d8ef289c7e87fc5..6b9cab1056beaa6f516a0d3a202a7816c911f1b2 100644 --- a/lite/backends/opencl/cl_caller.cc +++ b/lite/backends/opencl/cl_caller.cc @@ -23,6 +23,7 @@ limitations under the License. */ namespace paddle { namespace lite { + static void CopyImageData(CLContext* context, const CLImage& cl_image, float* out) { @@ -51,119 +52,5 @@ bool InitOpenCLRuntime(std::string cl_path) { return runtime->IsInitSuccess(); } -void elementwise_add(CLContext* context, - const float* in, - const DDim& in_dim, - const float* bias, - const DDim& bias_dim, - float* out, - const DDim& out_dim) { - if (!(bias_dim.size() == 1 || bias_dim.size() == 4)) { - LOG(FATAL) << "Error: bias dims is error"; - return; - } - auto kernel = bias_dim.size() == 1 ? 
context->GetKernel("channel_add") - : context->GetKernel("elementwise_add"); - CLImage in_image; - in_image.set_tensor_data(in, in_dim); - in_image.InitNormalCLImage(context->GetContext()); - VLOG(3) << " --- Inpu image: " << in_image << " --- "; - CLImage bias_image; - bias_image.set_tensor_data(bias, bias_dim); - bias_image.InitCLImage(context->GetContext()); - VLOG(3) << " --- Bias image: " << bias_image << " --- "; - CLImage out_image; - out_image.InitEmptyImage(context->GetContext(), out_dim); - cl_int status; - status = kernel.setArg(0, *in_image.cl_image()); - CL_CHECK_FATAL(status); - status = kernel.setArg(1, *bias_image.cl_image()); - CL_CHECK_FATAL(status); - status = kernel.setArg(2, *out_image.cl_image()); - CL_CHECK_FATAL(status); - - if (bias_dim.size() == 1) { - int tensor_w = in_dim[3]; - status = kernel.setArg(3, tensor_w); - CL_CHECK_FATAL(status); - } - size_t width = in_image.ImageWidth(); - size_t height = in_image.ImageHeight(); - auto global_work_size = cl::NDRange{width, height}; - status = context->GetCommandQueue().enqueueNDRangeKernel( - kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr); - CL_CHECK_FATAL(status); - - status = context->GetCommandQueue().finish(); - CL_CHECK_FATAL(status); - VLOG(3) << " --- Out image: " << out_image << " --- "; - CopyImageData(context, out_image, out); -} - -void pool(CLContext* context, - const std::string pooling_type, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int ksize_h, - const int ksize_w, - const float* in, - const DDim& in_dim, - float* out, - const DDim& out_dim) { - auto kernel = - context->GetKernel(string_format("pool_%s", pooling_type.c_str())); - CLImage in_image; - in_image.set_tensor_data(in, in_dim); - in_image.InitNormalCLImage(context->GetContext()); - VLOG(3) << " --- Inpu image: " << in_image << " --- "; - CLImage out_image; - out_image.InitEmptyImage(context->GetContext(), out_dim); - auto global_work_size = context->DefaultWorkSize(out_image); - auto* in_converter = - dynamic_cast(in_image.image_converter()); - auto* out_converter = - dynamic_cast(out_image.image_converter()); - const int in_height = in_converter->HeightOfOneBlock(); - const int in_width = in_converter->WidthOfOneBlock(); - const int out_height = out_converter->HeightOfOneBlock(); - const int out_width = out_converter->WidthOfOneBlock(); - cl_int status; - status = kernel.setArg(0, in_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(1, in_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(2, out_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(3, out_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(4, pad_h); - CL_CHECK_FATAL(status); - status = kernel.setArg(5, pad_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(6, stride_h); - CL_CHECK_FATAL(status); - status = kernel.setArg(7, stride_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(8, ksize_h); - CL_CHECK_FATAL(status); - status = kernel.setArg(9, ksize_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(10, *in_image.cl_image()); - CL_CHECK_FATAL(status); - status = kernel.setArg(11, *out_image.cl_image()); - CL_CHECK_FATAL(status); - - status = context->GetCommandQueue().enqueueNDRangeKernel( - kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr); - CL_CHECK_FATAL(status); - - status = context->GetCommandQueue().finish(); - CL_CHECK_FATAL(status); - VLOG(3) << " --- Out image: " << out_image << " --- "; - CopyImageData(context, out_image, 
out); -} - } // namespace lite } // namespace paddle diff --git a/lite/backends/opencl/cl_caller.h b/lite/backends/opencl/cl_caller.h index ed5c9153d3cedf140cbf0570b7f71393fb918bf9..1817db9f6bd6d9ecf21978b8293bd9534328de0f 100644 --- a/lite/backends/opencl/cl_caller.h +++ b/lite/backends/opencl/cl_caller.h @@ -23,30 +23,5 @@ namespace lite { bool InitOpenCLRuntime(std::string cl_path); -/// An elementwise_add method to embed OpenCL logic inside, it is used as a -/// black box so that the framework can remain simple. -/// NOTE Currently, these methods are quite expensive, we will optimize them -/// latter. -void elementwise_add(CLContext* context, - const float* in, - const DDim& in_dim, - const float* bias, - const DDim& bias_dim, - float* out, - const DDim& out_dim); - -void pool(CLContext* context, - const std::string pooling_type, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int ksize_h, - const int ksize_w, - const float* in, - const DDim& in_dim, - float* out, - const DDim& out_dim); - } // namespace lite } // namespace paddle diff --git a/lite/backends/opencl/cl_functions_test.cc b/lite/backends/opencl/cl_functions_test.cc index b9f6648c9956e1952b65f66abfa40d912a99ee67..70f47b47946641edf4d023437b48d46cae93ca6e 100644 --- a/lite/backends/opencl/cl_functions_test.cc +++ b/lite/backends/opencl/cl_functions_test.cc @@ -41,9 +41,10 @@ TEST(cl_test, runtime_test) { auto &context = runtime->context(); auto program = runtime->CreateProgram( context, - runtime->cl_path() + "/cl_kernel/" + "image/elementwise_add_kernel.cl"); + runtime->cl_path() + "/cl_kernel/" + "buffer/elementwise_add_kernel.cl"); auto event = runtime->CreateEvent(context); - CHECK(runtime->BuildProgram(program.get())); + const std::string build_option("-DCL_DTYPE_float"); + CHECK(runtime->BuildProgram(program.get(), build_option)); } TEST(cl_test, context_test) { @@ -51,9 +52,11 @@ TEST(cl_test, context_test) { CHECK(runtime->IsInitSuccess()); runtime->set_cl_path(FLAGS_cl_path); CLContext context; - context.AddKernel("pool_max", "image/pool_kernel.cl", ""); - context.AddKernel("elementwise_add", "image/elementwise_add_kernel.cl", ""); - context.AddKernel("elementwise_add", "image/elementwise_add_kernel.cl", ""); + context.AddKernel("pool_max", "image/pool_kernel.cl", "-DCL_DTYPE_float"); + context.AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); + context.AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); } TEST(cl_test, kernel_test) { @@ -61,9 +64,11 @@ TEST(cl_test, kernel_test) { CHECK(runtime->IsInitSuccess()); runtime->set_cl_path(FLAGS_cl_path); std::unique_ptr context(new CLContext); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); - context->AddKernel("pool_max", "image/pool_kernel.cl"); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); + context->AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); + context->AddKernel("pool_max", "image/pool_kernel.cl", "-DCL_DTYPE_float"); + context->AddKernel( + "elementwise_add", "image/elementwise_add_kernel.cl", "-DCL_DTYPE_float"); auto kernel = context->GetKernel(2); std::unique_ptr in_data(new float[4 * 3 * 256 * 512]); @@ -115,203 +120,12 @@ TEST(cl_test, kernel_test) { LOG(INFO) << out_image; } -TEST(cl_test, channel_add_test) { - std::default_random_engine engine; - std::uniform_real_distribution dist(-5, 5); - - const DDim in_dim = DDim(std::vector{4, 16, 256, 
512}); - std::unique_ptr in_data(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - in_data[i] = dist(engine); - } - - const DDim bias_dim = DDim(std::vector{16}); - std::unique_ptr bias_data(new float[16]); - for (int i = 0; i < 16; i++) { - bias_data[i] = dist(engine); - } - - std::unique_ptr out_ref(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 16; j++) { - float b = bias_data[j]; - for (int k = 0; k < 256 * 512; k++) { - int index = (i * 16 + j) * 256 * 512 + k; - out_ref[index] = in_data[index] + b; - } - } - } - - const DDim out_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr out(new float[4 * 16 * 256 * 512]); - - bool status = InitOpenCLRuntime(FLAGS_cl_path); - CHECK(status) << "Fail to initialize OpenCL runtime."; - std::unique_ptr context(new CLContext); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); - context->AddKernel("channel_add", "image/channel_add_kernel.cl"); - elementwise_add(context.get(), - in_data.get(), - in_dim, - bias_data.get(), - bias_dim, - out.get(), - out_dim); - - int stride = 4 * 16 * 256 * 512 / 20; - for (int i = 0; i < 4 * 16 * 256 * 512; i += stride) { - std::cout << out[i] << " "; - } - std::cout << std::endl; - - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - EXPECT_NEAR(out[i], out_ref[i], 1e-6); - } -} - -TEST(cl_test, elementwise_add_test) { - std::default_random_engine engine; - std::uniform_real_distribution dist(-5, 5); - - const DDim in_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr in_data(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - in_data[i] = dist(engine); - } - - const DDim bias_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr bias_data(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - bias_data[i] = dist(engine); - } - - std::unique_ptr out_ref(new float[4 * 16 * 256 * 512]); - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - out_ref[i] = in_data[i] + bias_data[i]; - } - - const DDim out_dim = DDim(std::vector{4, 16, 256, 512}); - std::unique_ptr out(new float[4 * 16 * 256 * 512]); - - bool status = InitOpenCLRuntime(FLAGS_cl_path); - CHECK(status) << "Fail to initialize OpenCL runtime."; - std::unique_ptr context(new CLContext); - context->AddKernel("elementwise_add", "image/elementwise_add_kernel.cl"); - context->AddKernel("channel_add", "image/channel_add_kernel.cl"); - elementwise_add(context.get(), - in_data.get(), - in_dim, - bias_data.get(), - bias_dim, - out.get(), - out_dim); - - int stride = 4 * 16 * 256 * 512 / 20; - for (int i = 0; i < 4 * 16 * 256 * 512; i += stride) { - std::cout << out[i] << " "; - } - std::cout << std::endl; - - for (int i = 0; i < 4 * 16 * 256 * 512; i++) { - EXPECT_NEAR(out[i], out_ref[i], 1e-6); - } -} - -void pool_avg(const int padding_height, - const int padding_width, - const int stride_height, - const int stride_width, - const int ksize_height, - const int ksize_width, - const float *input_data, - const DDim &in_dim, - float *output_data, - const DDim &out_dim) { - const int batch_size = in_dim[0]; - const int input_height = in_dim[2]; - const int input_width = in_dim[3]; - const int output_channels = out_dim[1]; - const int output_height = out_dim[2]; - const int output_width = out_dim[3]; - - const size_t input_spatial_size = input_height * input_width; - const size_t output_spatial_size = output_height * output_width; - - for (int i = 0; i < batch_size; i++) { - for (int c = 0; c < 
output_channels; ++c) { - int channel = i * output_channels + c; - const float *input_ptr = input_data + channel * input_spatial_size; - float *output_ptr = output_data + channel * output_spatial_size; - - for (int ph = 0; ph < output_height; ++ph) { - int hstart = ph * stride_height - padding_height; - int hend = std::min(hstart + ksize_height, input_height); - hstart = std::max(hstart, 0); - for (int pw = 0; pw < output_width; ++pw) { - int wstart = pw * stride_width - padding_width; - int wend = std::min(wstart + ksize_width, input_width); - wstart = std::max(wstart, 0); - - float val = 0.f; - int count = 0; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - val += input_ptr[h * input_width + w]; - ++count; - } - } - output_ptr[ph * output_width + pw] = - (count > 0) ? val * (1.f / count) : 0.f; - } - } - } - } -} - -TEST(cl_test, pool_test) { - std::default_random_engine engine; - std::uniform_real_distribution dist(-5, 5); - - const DDim in_dim = DDim(std::vector{4, 1024, 7, 7}); - std::unique_ptr in_data(new float[4 * 1024 * 7 * 7]); - for (int i = 0; i < 4 * 1024 * 7 * 7; i++) { - in_data[i] = dist(engine); - } - - const DDim out_dim = DDim(std::vector{4, 1024, 1, 1}); - std::unique_ptr out(new float[4 * 1024 * 1 * 1]); - std::unique_ptr out_ref(new float[4 * 1024 * 1 * 1]); - - bool status = InitOpenCLRuntime(FLAGS_cl_path); - CHECK(status) << "Fail to initialize OpenCL runtime."; - std::unique_ptr context(new CLContext); - context->AddKernel("pool_max", "image/pool_kernel.cl"); - context->AddKernel("pool_avg", "image/pool_kernel.cl"); - pool(context.get(), - "avg", - 0, - 0, - 1, - 1, - 7, - 7, - in_data.get(), - in_dim, - out.get(), - out_dim); - pool_avg(0, 0, 1, 1, 7, 7, in_data.get(), in_dim, out_ref.get(), out_dim); - - for (int i = 0; i < 4 * 1024 * 1 * 1; i++) { - EXPECT_NEAR(out[i], out_ref[i], 1e-6); - } -} - TEST(cl_test, target_wrapper_buffer_test) { bool inited = InitOpenCLRuntime(FLAGS_cl_path); CHECK(inited) << "Fail to initialize OpenCL runtime."; std::unique_ptr context(new CLContext); std::string kernel_name = "elementwise_add"; - std::string build_options = "-DCL_DTYPE=float"; + std::string build_options = "-DCL_DTYPE_float"; context->AddKernel( kernel_name, "buffer/elementwise_add_kernel.cl", build_options); std::vector h_a; @@ -396,10 +210,13 @@ TEST(cl_test, target_wrapper_buffer_test) { TEST(cl_test, target_wrapper_image_test) { const size_t cl_image2d_width = 28; const size_t cl_image2d_height = 32; + const size_t cl_image2d_elem_size = + cl_image2d_width * cl_image2d_height * 4; // 4 for RGBA channels const size_t cl_image2d_row_pitch{0}; const size_t cl_image2d_slice_pitch{0}; auto *d_image = static_cast( TargetWrapperCL::MallocImage(cl_image2d_width, cl_image2d_height)); + // Map/Unmap test auto *h_image = static_cast(TargetWrapperCL::MapImage(d_image, @@ -407,15 +224,11 @@ TEST(cl_test, target_wrapper_image_test) { cl_image2d_height, cl_image2d_row_pitch, cl_image2d_slice_pitch)); - CHECK_EQ( - cl_image2d_row_pitch, - cl_image2d_width * 4 * - 4); // row_pitch = 448 = 28 * 4 (RGBA: 4 floats) * 4 (float in bytes) - CHECK_EQ(cl_image2d_slice_pitch, 0); // slice_pitch = 0 + CHECK_EQ(cl_image2d_slice_pitch, 0); LOG(INFO) << "cl_image2d_row_pitch = " << cl_image2d_row_pitch << ", cl_image2d_slice_pitch " << cl_image2d_slice_pitch; - for (int i = 0; i < 10; i++) { + for (int i = 0; i < cl_image2d_elem_size; i++) { h_image[i] = 3.14f * i; } TargetWrapperCL::Unmap(d_image, h_image); @@ -426,15 +239,14 @@ TEST(cl_test, 
target_wrapper_image_test) { cl_image2d_height, cl_image2d_row_pitch, cl_image2d_slice_pitch)); - for (int i = 0; i < 10; i++) { + for (int i = 0; i < cl_image2d_elem_size; i++) { EXPECT_NEAR(h_ptr[i], 3.14f * i, 1e-6); } TargetWrapperCL::Unmap(d_image, h_ptr); // Imagecpy test - std::vector h_image_cpy(cl_image2d_width * 4 * - cl_image2d_height); // 4 for RGBA channels - for (int i = 0; i < cl_image2d_width * 4 * cl_image2d_height; i++) { + std::vector h_image_cpy(cl_image2d_elem_size); + for (int i = 0; i < cl_image2d_elem_size; i++) { h_image_cpy[i] = 3.14f; } TargetWrapperCL::ImgcpySync(d_image, @@ -446,6 +258,8 @@ TEST(cl_test, target_wrapper_image_test) { IoDirection::HtoD); auto *d_image_cpy = static_cast( TargetWrapperCL::MallocImage(cl_image2d_width, cl_image2d_height)); + + // device to device TargetWrapperCL::ImgcpySync(d_image_cpy, d_image, cl_image2d_width, @@ -454,6 +268,8 @@ TEST(cl_test, target_wrapper_image_test) { cl_image2d_slice_pitch, IoDirection::DtoD); std::fill(h_image_cpy.begin(), h_image_cpy.end(), 0); + + // host to device TargetWrapperCL::ImgcpySync(h_image_cpy.data(), d_image_cpy, cl_image2d_width, @@ -461,7 +277,7 @@ TEST(cl_test, target_wrapper_image_test) { cl_image2d_row_pitch, cl_image2d_slice_pitch, IoDirection::DtoH); - for (int i = 0; i < cl_image2d_width * 4 * cl_image2d_height; i++) { + for (int i = 0; i < cl_image2d_elem_size; i++) { EXPECT_NEAR(h_image_cpy[i], 3.14f, 1e-6); } diff --git a/lite/backends/opencl/cl_image_converter.h b/lite/backends/opencl/cl_image_converter.h index 6faa8045576f06d8c636372de644e6b5c164a5f4..962eb8d3ef35bdb603aa4a56181b1124885d5506 100644 --- a/lite/backends/opencl/cl_image_converter.h +++ b/lite/backends/opencl/cl_image_converter.h @@ -103,6 +103,7 @@ class CLImageConverterNormal : public CLImageConverterBase { }; class CLImageConverterNWBlock : public CLImageConverterBase { + public: DDim InitImageDimInfoWith(const DDim &tensor_dim) override; void NCHWToImage(float *tensor, float *image, @@ -113,6 +114,7 @@ class CLImageConverterNWBlock : public CLImageConverterBase { const DDim &tensor_dim) override; }; class CLImageConverterDWBlock : public CLImageConverterBase { + public: DDim InitImageDimInfoWith(const DDim &tensor_dim) override; void NCHWToImage(float *tensor, float *image, diff --git a/lite/backends/opencl/cl_kernel/buffer/concat_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/concat_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..1574cb4a69cd0388698707d8d91c1d9c18b625a2 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/buffer/concat_kernel.cl @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <cl_common.h>
+
+__kernel void concat2(__global const CL_DTYPE* x_data0, __global const CL_DTYPE* x_data1, __global CL_DTYPE* out_data,
+                      int size, int axis_size, int pre_size, int post_size, int total, int total0, int total1) {
+  const int index = get_global_id(0);
+  if (index < size) {
+    for (int i = 0; i < pre_size; i++) {
+      int offset_out = index * post_size + i * total;
+      int offset_in = index * post_size + i * total0;
+      // memcpy(out_data + offset_out, x_data0 + offset_in, post_size);
+      __global CL_DTYPE* dst = out_data + offset_out;
+      __global const CL_DTYPE* src = x_data0 + offset_in;
+      for (int k = 0; k < post_size; k++) {
+        *dst++ = *src++;
+      }
+    }
+  } else if (index < axis_size) {
+    for (int i = 0; i < pre_size; i++) {
+      int offset_out = index * post_size + i * total;
+      int offset_in = index * post_size + i * total1;
+      // memcpy(out_data + offset_out, x_data1 + offset_in, post_size);
+      __global CL_DTYPE* dst = out_data + offset_out;
+      __global const CL_DTYPE* src = x_data1 + offset_in;
+      for (int k = 0; k < post_size; k++) {
+        *dst++ = *src++;
+      }
+    }
+  }
+}
+
+__kernel void concat_mul(__global const CL_DTYPE* x_data, __global CL_DTYPE* out_data,
+                         int axis_size, int pre_size, int post_size, int start, int total, int total0) {
+  const int index = get_global_id(0);
+  if (index < axis_size) {
+    for (int i = 0; i < pre_size; i++) {
+      int offset_out = (start + index) * post_size + i * total;
+      int offset_in = index * post_size + i * total0;
+      // memcpy(out_data + offset_out, x_data + offset_in, post_size);
+      __global CL_DTYPE* dst = out_data + offset_out;
+      __global const CL_DTYPE* src = x_data + offset_in;
+      for (int k = 0; k < post_size; k++) {
+        *dst++ = *src++;
+      }
+    }
+  }
+}
diff --git a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl
index c9c16581d67db0c9143e91e13249edfd5901ddb8..532f947dd342b1ee4db69a084111a97ec014237f 100644
--- a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl
+++ b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl
@@ -61,6 +61,57 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in,
   write_imagef(output_image, output_pos, output);
 }
 
+// buffer -> image2d_nw
+__kernel void buffer_to_image2d_nw(__global CL_DTYPE* in,
+                                   __write_only image2d_t output_image,
+                                   __private const int out_H,
+                                   __private const int out_W,
+                                   __private const int out_N,
+                                   __private const int Stride0,
+                                   __private const int Stride1,
+                                   __private const int Stride2) {
+  const int out_n = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_ch = get_global_id(2);
+
+  const int out_c = out_ch / out_H;
+  const int out_h = out_ch % out_H;
+
+  const int in_c = out_c;  // index of c in h direction
+
+  const int in_n0 = out_n * 4 + 0;
+  const int in_n1 = out_n * 4 + 1;
+  const int in_n2 = out_n * 4 + 2;
+  const int in_n3 = out_n * 4 + 3;
+
+  const int in_h = out_h;
+  const int in_w = out_w;
+
+  int input_pos0 = in_n0 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w;
+  int input_pos1 = in_n1 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w;
+  int input_pos2 = in_n2 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w;
+  int input_pos3 = in_n3 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w;
+
+  int2 output_pos;
+  output_pos.x = out_n * out_W + out_w;
+  output_pos.y = out_ch;
+
+  CL_DTYPE4 output = (CL_DTYPE4)0.0f;
+  output.x = convert_float(in[input_pos0]);
+  if (out_N - 4 * out_n >= 2) {
+    output.y = convert_float(in[input_pos1]);
+  }
+  if (out_N - 4 * out_n >= 3) {
+    output.z = convert_float(in[input_pos2]);
+  }
+  if (out_N - 4 * out_n >= 4) {
+    output.w = convert_float(in[input_pos3]);
+  }
+  write_imagef(output_image, output_pos, output);
+}
+
+
 // image2d -> buffer
 __kernel void image2d_to_buffer(__read_only image2d_t input,
                                 __private const int in_width,
diff --git a/lite/backends/opencl/cl_kernel/buffer/sigmoid_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/sigmoid_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..615bf892b321ba67043d41f6032caa758d78c16f
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/buffer/sigmoid_kernel.cl
@@ -0,0 +1,22 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cl_common.h>
+
+__kernel void sigmoid(__global const CL_DTYPE* x_data, const int count, __global CL_DTYPE* out_data) {
+  const int index = get_global_id(0);
+  if (index < count) {
+    out_data[index] = 1 / (1 + exp(-x_data[index]));
+  }
+}
diff --git a/lite/backends/opencl/cl_kernel/cl_common.h b/lite/backends/opencl/cl_kernel/cl_common.h
index 7f901fc994ffd82ccfe99f59614a3422260d0dc5..c127c6cec79cb2eb8d82ce6aa6190b23d373ff64 100644
--- a/lite/backends/opencl/cl_kernel/cl_common.h
+++ b/lite/backends/opencl/cl_kernel/cl_common.h
@@ -14,8 +14,17 @@ limitations under the License. */
 
 #pragma once
 
+/////////////////////////////////
+// fp16 enabled, MAX_VALUE, MIN_VALUE
+/////////////////////////////////
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define MAX_VALUE FLT_MAX
+#define MIN_VALUE -FLT_MAX
+
+/////////////////////////////////
+// CL_DTYPE_float / CL_DTYPE_half
+/////////////////////////////////
 // Data type: pass one of macros on host: [CL_DTYPE_float, CL_DTYPE_half]
 #ifdef CL_DTYPE_float
 #define CL_DTYPE float
@@ -27,31 +36,43 @@ limitations under the License.
*/ #define CL_DTYPE_CHAR h #endif +///////////////////////////////// +// GET_VEC_TYPE +///////////////////////////////// // Note: macro name replacement need twice parser #define GET_VEC_TYPE(type__, size__) type__##size__ #define VECTORIZED_TYPE(type__, size__) GET_VEC_TYPE(type__, size__) #define CL_DTYPE4 VECTORIZED_TYPE(CL_DTYPE, 4) +///////////////////////////////// +// CONVERT_TYPE_TO +///////////////////////////////// #define _CONVERT_TYPE_TO(value, type) convert_##type(value) #define CONVERT_TYPE_TO(value, type) _CONVERT_TYPE_TO(value, type) +///////////////////////////////// +// WRITE_IMG_TYPE / READ_IMG_TYPE +///////////////////////////////// #define _WRITE_IMG_TYPE(type_char, img, pos, value) \ write_image##type_char(img, pos, value) #define WRITE_IMG_TYPE(type_char, img, pos, value) \ _WRITE_IMG_TYPE(type_char, img, pos, value) -#define _READ_IMG_TYPE(type_char, img, pos, sampler) \ +#define _READ_IMG_TYPE(type_char, img, sampler, pos) \ read_image##type_char(img, sampler, pos) -#define READ_IMG_TYPE(type_char, img, pos, sampler) \ - _READ_IMG_TYPE(type_char, img, pos, sampler) +#define READ_IMG_TYPE(type_char, img, sampler, pos) \ + _READ_IMG_TYPE(type_char, img, sampler, pos) +///////////////////////////////// +// activation / activation_type4 +///////////////////////////////// inline CL_DTYPE activation(CL_DTYPE in #ifdef PRELU , CL_DTYPE prelu_alpha #endif ) { - CL_DTYPE output; + CL_DTYPE output = in; #ifdef PRELU output = select(prelu_alpha * in, in, in >= (CL_DTYPE)0); #endif @@ -59,5 +80,30 @@ inline CL_DTYPE activation(CL_DTYPE in #ifdef RELU output = fmax(in, (CL_DTYPE)0); #endif + +#ifdef RELU6 + output = clamp(in, (CL_DTYPE)0, (CL_DTYPE)6); +#endif + return output; +} + +inline CL_DTYPE4 activation_type4(CL_DTYPE4 in +#ifdef PRELU + , + CL_DTYPE4 prelu_alpha +#endif + ) { + CL_DTYPE4 output = in; +#ifdef PRELU + output = select(prelu_alpha * in, in, in >= (CL_DTYPE4)0.0); +#endif + +#ifdef RELU + output = fmax(in, (CL_DTYPE4)0); +#endif + +#ifdef RELU6 + output = clamp(in, (CL_DTYPE4)0, (CL_DTYPE4)6); +#endif return output; } diff --git a/lite/backends/opencl/cl_kernel/image/concat_kernel.cl b/lite/backends/opencl/cl_kernel/image/concat_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..f0335116f87aac34740dd22ac68f2b6265e62445 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/concat_kernel.cl @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <cl_common.h>
+
+__kernel void concat2(__read_only image2d_t input0,
+                      __read_only image2d_t input1,
+                      __write_only image2d_t output,
+                      int axis_size, int flag, int width) {
+  const int x = get_global_id(0);  // image_width cxw/4
+  const int y = get_global_id(1);  // image_height nxh
+
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+  int xx = x / width;
+  if (flag == 0) {
+    xx = y / width;
+  }
+  if (xx < axis_size) {
+    CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(x, y));
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
+  } else {
+    int new_val = xx - axis_size;
+    new_val *= width;
+    CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, sampler, (int2)(new_val, y));
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
+  }
+  // WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
+}
+
+__kernel void concat_mul(__read_only image2d_t input0,
+                         __write_only image2d_t output,
+                         int axis_size, int flag, int width, int start) {
+  const int x = get_global_id(0);  // image_width cxw/4
+  const int y = get_global_id(1);  // image_height nxh
+
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+  int xx = x / width;
+  if (flag == 0) {
+    xx = y / width;
+  }
+
+  if (xx < axis_size && xx >= start) {
+    xx -= start;
+    xx *= width;
+    CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, (int2)(xx, y));
+    WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
+  }
+}
diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..37e03e802c56d3de9ba08e97c9dfb62f8cd76e9a
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl
@@ -0,0 +1,385 @@
+#include <cl_common.h>
+
+__kernel void conv2d_1x1(__private const int global_size_dim0,
+                         __private const int global_size_dim1,
+                         __private const int global_size_dim2,
+                         __read_only image2d_t input_image,
+                         __read_only image2d_t filter,
+#if defined(BIASE_CH) || defined(BIASE_ELE)
+                         __read_only image2d_t bias,
+#endif
+#ifdef BATCH_NORM
+                         __read_only image2d_t new_scale,
+                         __read_only image2d_t new_biase,
+#endif
+                         __write_only image2d_t output_image,
+                         __private const int stride,
+                         __private const int offset,
+                         __private const int input_c,
+                         __private const int input_c_origin,
+                         __private const int dilation,
+                         __private const int input_width,  /* of one block */
+                         __private const int input_height, /* of one block */
+                         __private const int output_width,
+                         __private const int output_height,
+                         __private const int old_w) {
+  CL_DTYPE zero = 0.0f;
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+
+  int out_w0 = out_w;
+  int out_w1 = out_w + global_size_dim1;
+  int out_w2 = out_w + global_size_dim1 * 2;
+  int out_w3 = out_w + global_size_dim1 * 3;
+
+  int outpos_main = mul24(out_c, old_w);
+  int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh);
+  int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh);
+  int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh);
+  int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh);
+
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  int2 stride_xy = (int2)(stride, stride);
+
+  int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh);
+  int2 in_pos_in_one_block0 =
+      ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset);
+
+  int2
ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); + int2 in_pos_in_one_block1 = + ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); + int2 in_pos_in_one_block2 = + ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); + int2 in_pos_in_one_block3 = + ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); + +#ifdef BIASE_CH + CL_DTYPE4 output0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; +#elif defined(BIASE_ELE) + CL_DTYPE4 output0 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos0); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; + +#else + CL_DTYPE4 output0 = 0.0f; + CL_DTYPE4 output1 = 0.0f; + CL_DTYPE4 output2 = 0.0f; + CL_DTYPE4 output3 = 0.0f; +#endif + + int max_w_bound = input_c * input_width; + int burndary_index = input_c * 4 - input_c_origin; + bool burndary_index_w = + burndary_index == 1 || burndary_index == 2 || burndary_index == 3; + bool burndary_index_z = burndary_index == 2 || burndary_index == 3; + bool burndary_index_y = burndary_index == 3; + + for (int i = 0; i < input_c; ++i) { + // ------------0--------------- + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, + in_pos_in_one_block0.y); + CL_DTYPE4 input0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + CL_DTYPE4 weight0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 0)); + CL_DTYPE4 weight1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 1)); + CL_DTYPE4 weight2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 2)); + CL_DTYPE4 weight3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 3)); + int bound_gap = max_w_bound - pos_in.x - 1; + + bool outof_bound = bound_gap < input_width && bound_gap >= 0; + input0.w = select(input0.w, zero, outof_bound && burndary_index_w); + input0.z = select(input0.z, zero, outof_bound && burndary_index_z); + input0.y = select(input0.y, zero, outof_bound && burndary_index_y); + + output0 = mad(input0.x, weight0, output0); + output0 = mad(input0.y, weight1, output0); + output0 = mad(input0.z, weight2, output0); + output0 = mad(input0.w, weight3, output0); + // -------------1-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, + in_pos_in_one_block1.y); + CL_DTYPE4 input1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input1.w = select(input1.w, zero, outof_bound && burndary_index_w); + input1.z = select(input1.z, zero, outof_bound && burndary_index_z); + input1.y = select(input1.y, zero, outof_bound && burndary_index_y); + + output1 = mad(input1.x, weight0, output1); + output1 = mad(input1.y, weight1, output1); + output1 = mad(input1.z, weight2, output1); + output1 = mad(input1.w, weight3, output1); + + // -------------2-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, + in_pos_in_one_block2.y); + CL_DTYPE4 input2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input2.w = select(input2.w, zero, outof_bound && burndary_index_w); + input2.z = select(input2.z, zero, 
outof_bound && burndary_index_z); + input2.y = select(input2.y, zero, outof_bound && burndary_index_y); + + output2 = mad(input2.x, weight0, output2); + output2 = mad(input2.y, weight1, output2); + output2 = mad(input2.z, weight2, output2); + output2 = mad(input2.w, weight3, output2); + + // -------------3-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, + in_pos_in_one_block3.y); + CL_DTYPE4 input3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input3.w = + select(input3.w, + zero, + outof_bound && (burndary_index == 1 || burndary_index == 2 || + burndary_index == 3)); + input3.z = + select(input3.z, + zero, + outof_bound && (burndary_index == 2 || burndary_index == 3)); + input3.y = select(input3.y, zero, outof_bound && burndary_index == 3); + + output3 = mad(input3.x, weight0, output3); + output3 = mad(input3.y, weight1, output3); + output3 = mad(input3.z, weight2, output3); + output3 = mad(input3.w, weight3, output3); + } + +#ifdef BATCH_NORM + output0 = output0 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output1 = output1 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output2 = output2 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output3 = output3 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output0 = activation_type4(output0); + output1 = activation_type4(output1); + output2 = activation_type4(output2); + output3 = activation_type4(output3); +#endif + + if (out_w0 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos0, output0); + } + + if (out_w1 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos1, output1); + } + + if (out_w2 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos2, output2); + } + + if (out_w3 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos3, output3); + } +} + +__kernel void conv2d_1x1_simple(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM +__read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int input_c_origin, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height, + __private const int old_w) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int out_w0 = out_w; + int out_w1 = out_w + global_size_dim1; + int out_w2 = out_w + global_size_dim1 * 2; + int out_w3 = out_w + global_size_dim1 * 3; + + int outpos_main = mul24(out_c, old_w); + int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); + int2 
output_pos1 = (int2)(outpos_main + out_w1, out_nh); + int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh); + int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 stride_xy = (int2)(stride, stride); + + int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); + int2 in_pos_in_one_block0 = + ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); + int2 in_pos_in_one_block1 = + ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); + int2 in_pos_in_one_block2 = + ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); + int2 in_pos_in_one_block3 = + ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); + +#ifdef BIASE_CH + CL_DTYPE4 output0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; +#elif defined(BIASE_ELE) + CL_DTYPE4 output0 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos0); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; + +#else + CL_DTYPE4 output0 = 0.0f; + CL_DTYPE4 output1 = 0.0f; + CL_DTYPE4 output2 = 0.0f; + CL_DTYPE4 output3 = 0.0f; +#endif + + for (int i = 0; i < input_c; ++i) { + // ------------0--------------- + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, + in_pos_in_one_block0.y); + CL_DTYPE4 input0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + CL_DTYPE4 weight0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 0)); + CL_DTYPE4 weight1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 1)); + CL_DTYPE4 weight2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 2)); + CL_DTYPE4 weight3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 3)); + + output0 = mad(input0.x, weight0, output0); + output0 = mad(input0.y, weight1, output0); + output0 = mad(input0.z, weight2, output0); + output0 = mad(input0.w, weight3, output0); + + pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, + in_pos_in_one_block1.y); + CL_DTYPE4 input1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + output1 = mad(input1.x, weight0, output1); + output1 = mad(input1.y, weight1, output1); + output1 = mad(input1.z, weight2, output1); + output1 = mad(input1.w, weight3, output1); + + pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, + in_pos_in_one_block2.y); + CL_DTYPE4 input2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + output2 = mad(input2.x, weight0, output2); + output2 = mad(input2.y, weight1, output2); + output2 = mad(input2.z, weight2, output2); + output2 = mad(input2.w, weight3, output2); + + pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, + in_pos_in_one_block3.y); + CL_DTYPE4 input3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + output3 = mad(input3.x, weight0, output3); + output3 = mad(input3.y, weight1, output3); + output3 = mad(input3.z, weight2, output3); + output3 = mad(input3.w, weight3, output3); + } + +#ifdef BATCH_NORM + output0 = output0 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output1 = output1 * 
READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output2 = output2 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output3 = output3 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + + + output0 = activation_type4(output0); + output1 = activation_type4(output1); + output2 = activation_type4(output2); + output3 = activation_type4(output3); + + + if (out_w0 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos0, output0); + } + + if (out_w1 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos1, output1); + } + + if (out_w2 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos2, output2); + } + + if (out_w3 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos3, output3); + } +} diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..8d7950d6b897df833ada56e2de5be7c6203de9ea --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl @@ -0,0 +1,428 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +__kernel void conv2d_3x3(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width,/* of one block */ + __private const int input_height,/* of one block */ + __private const int output_width, + __private const int output_height, + __private const int output_c, + __private const int filter_channel, + __private const int filter_width, + __private const int filter_height, + __private const int group) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); + + if (out_c >= global_size_dim0 || + out_w >= global_size_dim1 || + out_nh >= global_size_dim2) { + return; + } + + + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; + + int2 ouput_pos_in_one_block; + ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh; + + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; + +#ifdef BIASE_CH + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos); +#else + CL_DTYPE4 output = 0.0f; +#endif + + CL_DTYPE4 input[9]; // 3x3 region of input + if (group == 1) { + for (int i = 0; i < input_c; ++i) { // each run for 3x3 + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + + input[0] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + + input[1] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + + input[2] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + + input[3] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + input[4] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || 
in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + input[5] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + input[6] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + input[7] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + input[8] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + int j = 0; + int2 pos_of_weight; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + CL_DTYPE4 weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y += 3; + CL_DTYPE4 weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y += 3; + CL_DTYPE4 weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y += 3; + CL_DTYPE4 weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 1; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 2; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += 
dot(input[j], weight_w); + + j = 3; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 4; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 5; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 6; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 7; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + 
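+        // Filter image layout used by this unrolled loop: x indexes the
+        // input-channel block and 3x3 column (i * 3 + j % 3); y indexes the
+        // output-channel block, output lane, and 3x3 row
+        // (out_c * 4 * 3 + lane * 3 + j / 3). Each READ_IMG_TYPE therefore
+        // fetches the four input-channel taps of one filter cell, and
+        // dot(input[j], weight_*) reduces over those four channels at once.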
output.w += dot(input[j], weight_w); + + j = 8; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + } + } else { // group != 1 + for (int i = 0; i < 4; i++) { + int used_input_channel_num = + (out_c * 4 + i) / (output_c / group) * filter_channel; + for (int f_c = 0; f_c < filter_channel; ++f_c) { + int input_c = used_input_channel_num + f_c; + int input_block = input_c / 4; + int2 pos_in = (int2)(input_block * input_width + in_pos_in_one_block.x, + in_pos_in_one_block.y); + input[0] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || + in_pos_in_one_block.y - dilation < 0 || + in_pos_in_one_block.x - dilation >= input_width || + in_pos_in_one_block.y - dilation >= input_height) + << 15)); + input[1] = + select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || + in_pos_in_one_block.y - dilation < 0 || + in_pos_in_one_block.x >= input_width || + in_pos_in_one_block.y - dilation >= input_height) + << 15)); + input[2] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y - dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || + in_pos_in_one_block.y - dilation < 0 || + in_pos_in_one_block.x + dilation >= input_width || + in_pos_in_one_block.y - dilation >= input_height) + << 15)); + input[3] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || + in_pos_in_one_block.y < 0 || + in_pos_in_one_block.x - dilation >= input_width || + in_pos_in_one_block.y >= input_height) + << 15)); + input[4] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(pos_in.x, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || + in_pos_in_one_block.x >= input_width || + in_pos_in_one_block.y >= input_height) + << 15)); + input[5] = + select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || + in_pos_in_one_block.y < 0 || + in_pos_in_one_block.x + dilation >= input_width || + in_pos_in_one_block.y >= input_height) + << 15)); + input[6] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || + in_pos_in_one_block.y + dilation < 0 || + in_pos_in_one_block.x - dilation >= input_width || + in_pos_in_one_block.y + dilation >= input_height) + << 15)); + input[7] = + select(READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + 
(int2)(pos_in.x, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || + in_pos_in_one_block.y + dilation < 0 || + in_pos_in_one_block.x >= input_width || + in_pos_in_one_block.y + dilation >= input_height) + << 15)); + input[8] = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y + dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || + in_pos_in_one_block.y + dilation < 0 || + in_pos_in_one_block.x + dilation >= input_width || + in_pos_in_one_block.y + dilation >= input_height) + << 15)); + + CL_DTYPE tmp_out = 0; + for (int j = 0; j < 9; j++) { + int2 pos_of_weight; + pos_of_weight.x = (f_c / 4) * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + i * 3 + j / 3; + CL_DTYPE4 weight = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); + + int f_c_offset = f_c % 4; + CL_DTYPE f_value; + if (f_c_offset == 0) { + f_value = weight.x; + } else if (f_c_offset == 1) { + f_value = weight.y; + } else if (f_c_offset == 2) { + f_value = weight.z; + } else if (f_c_offset == 3) { + f_value = weight.w; + } + + int input_c_offset = input_c % 4; + CL_DTYPE input_value; + if (input_c_offset == 0) { + input_value = input[j].x; + } else if (input_c_offset == 1) { + input_value = input[j].y; + } else if (input_c_offset == 2) { + input_value = input[j].z; + } else if (input_c_offset == 3) { + input_value = input[j].w; + } + tmp_out += f_value * input_value; + } + + if (i == 0) { + output.x += tmp_out; + } else if (i == 1) { + output.y += tmp_out; + } else if (i == 2) { + output.z += tmp_out; + } else if (i == 3) { + output.w += tmp_out; + } + } + } + } + + output = activation_type4(output); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..d856af6a1d4026b1595bc287901e53f64267dc81 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl @@ -0,0 +1,169 @@ +#include + +__kernel void conv2d_5x5(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter_image, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); + + if (out_c >= global_size_dim0 || out_w >= global_size_dim1 || + out_nh >= global_size_dim2) { + return; + } + + const int batch_index = out_nh / output_height; + const int out_nh_in_one_batch = out_nh % output_height; + + const int filter_n0 = 4 * out_c + 0; + const int filter_n1 = 4 * out_c + 1; + const int filter_n2 = 4 * out_c + 2; + const int filter_n3 = 4 * out_c + 3; + + int2 stride_xy; + stride_xy.x = stride; + stride_xy.y = stride; + + int2 ouput_pos_in_one_block; 
+ ouput_pos_in_one_block.x = out_w; + ouput_pos_in_one_block.y = out_nh_in_one_batch; + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 in_pos_in_one_block; + in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset; + in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset; + +#ifdef BIASE_CH + CL_DTYPE4 output = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos); +#else + CL_DTYPE4 output = 0.0f; +#endif + + CL_DTYPE4 input; + CL_DTYPE4 filter[4]; + int2 filter_pos0; + int2 filter_pos1; + int2 filter_pos2; + int2 filter_pos3; + for (int i = 0; i < input_c; ++i) { + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, + in_pos_in_one_block.y + batch_index * input_height); + for (int j = 0; j < 5; j++) { + for (int k = 0; k < 5; k++) { + input = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, + input_image, + sampler, + (int2)(pos_in.x + (j - 2) * dilation, + pos_in.y + (k - 2) * dilation)), + (CL_DTYPE4)(0.0f), + (ushort4)( + (in_pos_in_one_block.x + (j - 2) * dilation < 0 || + in_pos_in_one_block.y + (k - 2) * dilation < 0 || + in_pos_in_one_block.x + (j - 2) * dilation >= input_width || + in_pos_in_one_block.y + (k - 2) * dilation >= input_height) + << 15)); + int filter_h = k; + int filter_w = j; + int filter_c = i; + + filter_pos0.x = filter_c * 5 + filter_w; + filter_pos0.y = filter_n0 * 5 + filter_h; + + filter_pos1.x = filter_c * 5 + filter_w; + filter_pos1.y = filter_n1 * 5 + filter_h; + + filter_pos2.x = filter_c * 5 + filter_w; + filter_pos2.y = filter_n2 * 5 + filter_h; + + filter_pos3.x = filter_c * 5 + filter_w; + filter_pos3.y = filter_n3 * 5 + filter_h; + + filter[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos0); + filter[1] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos1); + filter[2] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos2); + filter[3] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos3); + + output.x += dot(input, filter[0]); + output.y += dot(input, filter[1]); + output.z += dot(input, filter[2]); + output.w += dot(input, filter[3]); + // + // if (output_pos.x == 0 && output_pos.y == 5) { + // printf("i,j,k ={ %d, %d , %d }\n", i,j,k); + // printf("in={ %f , %f , %f , %f } \n", + // convert_float(input.x), + // convert_float(input.y), + // convert_float(input.z), + // convert_float(input.w)); + // printf("filter0={ %f , %f , %f , %f } \n", + // convert_float(filter[0].x), + // convert_float(filter[0].y), + // convert_float(filter[0].z), + // convert_float(filter[0].w)); + // printf("filter1={ %f , %f , %f , %f } \n", + // convert_float(filter[1].x), + // convert_float(filter[1].y), + // convert_float(filter[1].z), + // convert_float(filter[1].w)); + // printf("filter2={ %f , %f , %f , %f } \n", + // convert_float(filter[2].x), + // convert_float(filter[2].y), + // convert_float(filter[2].z), + // convert_float(filter[2].w)); + // printf("filter3={ %f , %f , %f , %f } \n", + // convert_float(filter[3].x), + // convert_float(filter[3].y), + // convert_float(filter[3].z), + // convert_float(filter[3].w)); + // printf("output={ %f , %f , %f , %f } \n", + // convert_float(output.x), + // convert_float(output.y), + // convert_float(output.z), + // convert_float(output.w)); + // } + } + } + } + +#ifdef BATCH_NORM + output = + output * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, 
(int2)(out_c, 0)) +
+           READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0));
+#endif
+
+  output = activation_type4(output);
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
+}
diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..1f99322812c13287af92b52aee6c346309ee006c
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl
@@ -0,0 +1,134 @@
+#include <cl_common.h>
+
+__kernel void conv2d_7x7(__private const int global_size_dim0,
+                         __private const int global_size_dim1,
+                         __private const int global_size_dim2,
+                         __read_only image2d_t input_image,
+                         __read_only image2d_t filter_image,
+#if defined(BIASE_CH) || defined(BIASE_ELE)
+                         __read_only image2d_t bias,
+#endif
+#ifdef BATCH_NORM
+                         __read_only image2d_t new_scale,
+                         __read_only image2d_t new_biase,
+#endif
+                         __write_only image2d_t output_image,
+                         __private const int stride,
+                         __private const int offset,
+                         __private const int input_c,
+                         __private const int dilation,
+                         __private const int input_width,  /* of one block */
+                         __private const int input_height, /* of one block */
+                         __private const int output_width,
+                         __private const int output_height) {
+
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+
+  int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
+
+  if (out_c >= global_size_dim0 || out_w >= global_size_dim1 ||
+      out_nh >= global_size_dim2) {
+    return;
+  }
+
+  const int batch_index = out_nh / output_height;
+  const int out_nh_in_one_batch = out_nh % output_height;
+
+  const int filter_n0 = 4 * out_c + 0;
+  const int filter_n1 = 4 * out_c + 1;
+  const int filter_n2 = 4 * out_c + 2;
+  const int filter_n3 = 4 * out_c + 3;
+
+  int2 stride_xy;
+  stride_xy.x = stride;
+  stride_xy.y = stride;
+
+  int2 ouput_pos_in_one_block;
+  ouput_pos_in_one_block.x = out_w;
+  ouput_pos_in_one_block.y = out_nh_in_one_batch;
+
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  int2 in_pos_in_one_block;
+  in_pos_in_one_block.x = ouput_pos_in_one_block.x * stride + offset;
+  in_pos_in_one_block.y = ouput_pos_in_one_block.y * stride + offset;
+
+#ifdef BIASE_CH
+  CL_DTYPE4 output =
+      READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0));
+#elif defined(BIASE_ELE)
+  CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos);
+#else
+  CL_DTYPE4 output = 0.0f;
+#endif
+
+  CL_DTYPE4 input;
+  CL_DTYPE4 filter[4];
+  int2 filter_pos0;
+  int2 filter_pos1;
+  int2 filter_pos2;
+  int2 filter_pos3;
+  for (int i = 0; i < input_c; ++i) {
+    int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x,
+                         in_pos_in_one_block.y + batch_index * input_height);
+    for (int j = 0; j < 7; j++) {
+      for (int k = 0; k < 7; k++) {
+        input = select(
+            READ_IMG_TYPE(CL_DTYPE_CHAR,
+                          input_image,
+                          sampler,
+                          (int2)(pos_in.x + (j - 3) * dilation,
+                                 pos_in.y + (k - 3) * dilation)),
+            (CL_DTYPE4)(0.0f),
+            (ushort4)(
+                (in_pos_in_one_block.x + (j - 3) * dilation < 0 ||
+                 in_pos_in_one_block.y + (k - 3) * dilation < 0 ||
+                 in_pos_in_one_block.x + (j - 3) * dilation >= input_width ||
+                 in_pos_in_one_block.y + (k - 3) * dilation >= input_height)
+                << 15));
+        int filter_h = k;
+        int filter_w = j;
+        int filter_c = i;
+
+        filter_pos0.x = filter_c * 7 + filter_w;
+        filter_pos0.y = filter_n0 * 7 + filter_h;
+
+        filter_pos1.x = filter_c * 7 + filter_w;
+        filter_pos1.y =
filter_n1 * 7 + filter_h; + + filter_pos2.x = filter_c * 7 + filter_w; + filter_pos2.y = filter_n2 * 7 + filter_h; + + filter_pos3.x = filter_c * 7 + filter_w; + filter_pos3.y = filter_n3 * 7 + filter_h; + + filter[0] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos0); + filter[1] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos1); + filter[2] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos2); + filter[3] = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter_image, sampler, filter_pos3); + + output.x += dot(input, filter[0]); + output.y += dot(input, filter[1]); + output.z += dot(input, filter[2]); + output.w += dot(input, filter[3]); + } + } + } + +#ifdef BATCH_NORM + output = output * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + + output = activation_type4(output); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl new file mode 100755 index 0000000000000000000000000000000000000000..27313aea23ed16ecc7a6763dfbbbe63bca18941a --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl @@ -0,0 +1,101 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +__kernel void depth_conv2d(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height, + __private const int filter_width, + __private const int filter_height) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + const int batch_index = out_nh / output_height; + const int out_nh_in_one_batch = out_nh % output_height; + int2 stride_xy = (int2)(stride, stride); + int2 ouput_pos_in_one_block = (int2)(out_w, out_nh_in_one_batch); + int2 in_pos_in_one_block = + ouput_pos_in_one_block * stride_xy + (int2)(offset, offset); +#ifdef BIASE_CH + CL_DTYPE4 output = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos); +#else + CL_DTYPE4 output = 0.0f; +#endif + + int2 pos_in_input_block = + (int2)(out_c * input_width, batch_index * input_height); + int2 pos_in_filter_block = + (int2)(out_c * filter_width, batch_index * filter_height); + int filter_x = pos_in_filter_block.x; + int filter_y = pos_in_filter_block.y; + int input_x_base = pos_in_input_block.x + in_pos_in_one_block.x; + int input_y_base = pos_in_input_block.y + in_pos_in_one_block.y; + int2 align = {filter_width / 2, filter_height / 2}; + for (int fy = 0; fy < filter_height; ++fy) { + for (int fx = 0; fx < filter_width; ++fx) { + int x_off = fx - align.x; + int y_off = fy - align.y; + CL_DTYPE4 in = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, + input, + sampler, + (int2)(input_x_base + x_off, input_y_base + y_off)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + x_off < 0 || + in_pos_in_one_block.y + y_off < 0 || + in_pos_in_one_block.x + x_off >= input_width || + in_pos_in_one_block.y + y_off >= input_height) + << 15)); + CL_DTYPE4 f = READ_IMG_TYPE( + CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + fx, filter_y + fy)); + output += in * f; + } + } +#ifdef BATCH_NORM + output = output * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + + output = activation_type4(output); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl new file mode 100755 index 0000000000000000000000000000000000000000..14086dcd16bd1a8770f444bdcd0b6bea78e23b7e --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl @@ -0,0 +1,322 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#include + +__kernel void depth_conv2d_3x3(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int dilation, + __private const int input_c, + __private const int input_width,/* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); + + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + const int batch_index = out_nh / output_height; + + const int out_nh_in_one_batch = out_nh % output_height; + + + int2 stride_xy = (int2)(stride, stride); + int2 ouput_pos_in_one_block = (int2)(out_w, out_nh_in_one_batch); + + int2 in_pos_in_one_block = ouput_pos_in_one_block * stride_xy + (int2)(offset, offset); + +#ifdef BIASE_CH + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos); +#else + CL_DTYPE4 output = 0.0f; +#endif + + const int filter_width = 3; + const int filter_height = 3; + + int2 pos_in_input_block = (int2)(out_c * input_width, batch_index * input_height); + + int2 pos_in_filter_block = (int2)(out_c * filter_width, batch_index * filter_height); + + int filter_x = pos_in_filter_block.x ; + int filter_y = pos_in_filter_block.y ; + + CL_DTYPE4 inputs[9]; + + inputs[0] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x - 1, pos_in_input_block.y + in_pos_in_one_block.y - 1)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - 1 < 0 || in_pos_in_one_block.y - 1 < 0 || in_pos_in_one_block.x - 1 >= input_width || in_pos_in_one_block.y - 1 >= input_height) << 15)); + + inputs[1] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x, pos_in_input_block.y + in_pos_in_one_block.y - 1)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - 1 < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - 1 >= input_height) << 15)); + + inputs[2] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x + 1, pos_in_input_block.y + in_pos_in_one_block.y - 1)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + 1 < 0 || in_pos_in_one_block.y - 1 < 0 || in_pos_in_one_block.x + 1 >= input_width 
|| in_pos_in_one_block.y - 1 >= input_height) << 15)); + + inputs[3] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x - 1, pos_in_input_block.y + in_pos_in_one_block.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - 1 < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - 1 >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + /* + if (output_pos.x == 112 && output_pos.y == 0) { + CL_DTYPE4 input1 = inputs[3]; + float4 in = (float4)(input1.x, input1.y, input1.z, input1.w); + printf(" input4 3 - %v4hlf \n", in); + printf(" --- %d ---\n", in_pos_in_one_block.x - 1); + } + */ + + + inputs[4] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x, pos_in_input_block.y + in_pos_in_one_block.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + inputs[5] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x + 1, pos_in_input_block.y + in_pos_in_one_block.y)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + 1 < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + 1 >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + + inputs[6] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x - 1, pos_in_input_block.y + in_pos_in_one_block.y + 1)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x - 1 < 0 || in_pos_in_one_block.y + 1 < 0 || in_pos_in_one_block.x - 1 >= input_width || in_pos_in_one_block.y + 1 >= input_height) << 15)); + + inputs[7] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x, pos_in_input_block.y + in_pos_in_one_block.y + 1)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + 1 < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + 1 >= input_height) << 15)); + + inputs[8] = select(READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_input_block.x + in_pos_in_one_block.x + 1, pos_in_input_block.y + in_pos_in_one_block.y + 1)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + 1 < 0 || in_pos_in_one_block.y + 1 < 0 || in_pos_in_one_block.x + 1 >= input_width || in_pos_in_one_block.y + 1 >= input_height) << 15)); + + CL_DTYPE4 filters[9]; + filters[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y)); + filters[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y)); + filters[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y)); + filters[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y + 1)); + filters[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y + 1)); + filters[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y + 1)); + filters[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y + 2)); + filters[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y + 2)); + filters[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y + 2)); + + for(int i = 0 ;i < 9 ; i++){ + output += inputs[i] * filters[i]; + } +#ifdef BATCH_NORM + output = output * READ_IMG_TYPE(CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + 
+ CL_DTYPE4 filters[9]; + filters[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y)); + filters[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y)); + filters[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y)); + filters[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y + 1)); + filters[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y + 1)); + filters[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y + 1)); + filters[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y + 2)); + filters[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y + 2)); + filters[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y + 2)); + + for (int i = 0; i < 9; i++) { + output += inputs[i] * filters[i]; + } +#ifdef BATCH_NORM + output = output * READ_IMG_TYPE(CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output = activation_type4(output); +#endif + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} + + +__kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk, + __private const int ou_w_blk, + __private const int ou_nh, + __read_only image2d_t input, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int pad, + __private const int dilation, + __private const int in_ch, + __private const int in_w, /* of one block */ + __private const int in_h, /* of one block */ + __private const int ou_w, + __private const int ou_h) { + + const int ou_ch_blk_id = get_global_id(0); + const int ou_w_blk_id = get_global_id(1); + const int ou_nh_id = get_global_id(2); + const int w_blk_size = 2; + + const int batch_id = ou_nh_id / ou_h; + int ou_col_id = ou_w_blk_id * w_blk_size; + int ou_row_id = ou_nh_id % ou_h; + int ou_x = mad24(ou_ch_blk_id, ou_w, ou_col_id); + + // input pos in one block and on batch + int col_id = ou_col_id - pad; + int row_id = ou_row_id - pad; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +#ifdef BIASE_CH + CL_DTYPE4 output[2]; + output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(ou_ch_blk_id, 0)); + output[1] = output[0]; +#elif defined(BIASE_ELE) + CL_DTYPE4 output[2]; + output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(ou_x, ou_nh_id)); + if (ou_col_id + 1 < ou_w) { + output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(ou_x + 1, ou_nh_id)); + } +#else + CL_DTYPE4 output[2] = {0.0f}; +#endif + + CL_DTYPE4 inputs[12]; + + int filter_x = ou_ch_blk_id * 3; + int filter_y = 0; + CL_DTYPE4 filters[9]; + filters[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y)); + filters[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y)); + filters[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y)); + + int in_x = mad24(ou_ch_blk_id, in_w, col_id); + int in_y = mad24(batch_id, in_h, row_id); + + int y0 = select(in_y, -1, row_id < 0 || row_id >= in_h); + int x0 = select(in_x, -1, col_id < 0 || col_id >= in_w); + inputs[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x0, y0)); + int x1 = select(in_x + 1, -1, col_id + 1 < 0 || col_id + 1 >= in_w); + inputs[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x1, y0)); + int x2 = select(in_x + 2, -1, col_id + 2 < 0 || col_id + 2 >= in_w); + inputs[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x2, y0)); + int x3 = select(in_x + 3, -1, col_id + 3 < 0 || col_id + 3 >=
in_w); + inputs[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x3, y0)); + + output[0] = mad(inputs[0], filters[0], output[0]); + output[1] = mad(inputs[1], filters[0], output[1]); + + output[0] = mad(inputs[1], filters[1], output[0]); + output[1] = mad(inputs[2], filters[1], output[1]); + + output[0] = mad(inputs[2], filters[2], output[0]); + output[1] = mad(inputs[3], filters[2], output[1]); + + + filters[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y + 1)); + filters[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y + 1)); + filters[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y + 1)); + + + int y1 = select(in_y + 1, -1, row_id + 1 < 0 || row_id + 1 >= in_h); + inputs[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x0, y1)); + inputs[5] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x1, y1)); + inputs[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x2, y1)); + inputs[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x3, y1)); + + + output[0] = mad(inputs[4], filters[3], output[0]); + output[1] = mad(inputs[5], filters[3], output[1]); + + output[0] = mad(inputs[5], filters[4], output[0]); + output[1] = mad(inputs[6], filters[4], output[1]); + + output[0] = mad(inputs[6], filters[5], output[0]); + output[1] = mad(inputs[7], filters[5], output[1]); + + + filters[6] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x,filter_y + 2)); + filters[7] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 1,filter_y + 2)); + filters[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler,(int2)(filter_x + 2,filter_y + 2)); + + int y2 = select(in_y + 2, -1, row_id + 2 < 0 || row_id + 2 >= in_h); + inputs[8] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x0, y2)); + inputs[9] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x1, y2)); + inputs[10] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x2, y2)); + inputs[11] = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x3, y2)); + + + output[0] = mad(inputs[8], filters[6], output[0]); + output[1] = mad(inputs[9], filters[6], output[1]); + + output[0] = mad(inputs[9], filters[7], output[0]); + output[1] = mad(inputs[10], filters[7], output[1]); + + output[0] = mad(inputs[10], filters[8], output[0]); + output[1] = mad(inputs[11], filters[8], output[1]); +#ifdef BATCH_NORM + CL_DTYPE4 scale = READ_IMG_TYPE(CL_DTYPE_CHAR, new_scale, sampler, (int2)(ou_ch_blk_id, 0)); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(ou_ch_blk_id, 0)); + output[0] = mad(scale, output[0], biase); + if (ou_col_id + 1 < ou_w) { + output[1] = mad(scale, output[1], biase); + } +#endif + +#ifdef RELU + output[0] = activation_type4(output[0]); + output[1] = activation_type4(output[1]); +#endif + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x, ou_nh_id), output[0]); + if (ou_col_id + 1 < ou_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x + 1, ou_nh_id), output[1]); + } + +} + diff --git a/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl b/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl index ecf719ae9316ed14743e872a1c2cde4b254b35ff..0d8867e6a79b57927c0d23ff549d3b845556dfd8 100644 --- a/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/elementwise_add_kernel.cl @@ -12,15 +12,74 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ -__kernel void elementwise_add(__read_only image2d_t input, __read_only image2d_t bias, __write_only image2d_t outputImage) { +#include <cl_common.h> + +__kernel void elementwise_add(__read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t outputImage) { int x = get_global_id(0); int y = get_global_id(1); + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 coords; coords.x = x; coords.y = y; - float4 in = read_imagef(input, sampler, coords); - float4 biase = read_imagef(bias, sampler, coords); - float4 output = in + biase; - write_imagef(outputImage,coords,output); + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords); + CL_DTYPE4 output = activation_type4(in + biase); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); } + +__kernel void channel_add(__read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t outputImage, + int w) { + int x = get_global_id(0); + int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 coords; + coords.x = x; + coords.y = y; + + // one bias texel per channel block, hence x / w (not x % w) + int2 coords_bias; + coords_bias.x = x / w; + coords_bias.y = 0; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output = in + biase; + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +} + +__kernel void width_add(__read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t outputImage, + int w) { + int x = get_global_id(0); + int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 coords; + coords.x = x; + coords.y = y; + + int2 coords_bias; + coords_bias.x = x % w; + coords_bias.y = 0; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output; + + output.x = in.x + biase.x; + output.y = in.y + biase.x; + output.z = in.z + biase.x; + output.w = in.w + biase.x; + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +}
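The three kernels above differ only in how the bias texel is addressed within the CHW4 image layout (image x = channel_block * width + column): elementwise_add reads the bias at the input's own coordinates, channel_add reads one CL_DTYPE4 per channel block, and width_add broadcasts a single scalar per column. A sketch of the two broadcast mappings, assuming that layout; helper names are illustrative:

// One bias texel per channel block: every column of a block shares it.
int2 bias_coords_per_channel(int x, int w) { return (int2)(x / w, 0); }
// One bias scalar per column: every channel block of a column shares it.
int2 bias_coords_per_width(int x, int w) { return (int2)(x % w, 0); }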
diff --git a/lite/backends/opencl/cl_kernel/image/elementwise_mul_kernel.cl b/lite/backends/opencl/cl_kernel/image/elementwise_mul_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..17b6e8c72a82718a541841ff3c69c175649d7056 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/elementwise_mul_kernel.cl @@ -0,0 +1,100 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <cl_common.h> + +__kernel void elementwise_mul(__read_only image2d_t input, __read_only image2d_t bias, + __write_only image2d_t outputImage) { + int x = get_global_id(0); + int y = get_global_id(1); + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 coords; + coords.x = x; + coords.y = y; + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords); + CL_DTYPE4 output = in * biase; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +} + +__kernel void channel_mul_d1(__read_only image2d_t input, __read_only image2d_t bias, + __write_only image2d_t outputImage, int w) { + int x = get_global_id(0); + int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 coords; + coords.x = x; + coords.y = y; + + int2 coords_bias; + coords_bias.x = x % w; + coords_bias.y = 0; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output = in * (CL_DTYPE4)(biase.x); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +} + +__kernel void channel_mul_d2(__read_only image2d_t input, __read_only image2d_t bias, + __write_only image2d_t outputImage, int w, int h) { + int x = get_global_id(0); + int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 coords; + coords.x = x; + coords.y = y; + + int2 coords_bias; + coords_bias.x = x % w; + coords_bias.y = y % h; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output = in * (CL_DTYPE4)(biase.x); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +} + +__kernel void channel_mul_d4(__read_only image2d_t input, __read_only image2d_t bias, + __write_only image2d_t outputImage, int w) { + int x = get_global_id(0); + int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 coords; + coords.x = x; + coords.y = y; + + int2 coords_bias; + coords_bias.x = x / w; + coords_bias.y = 0; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords); + CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias); + CL_DTYPE4 output = in * biase; + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output); +}
diff --git a/lite/backends/opencl/cl_kernel/image/nearest_interp_kernel.cl b/lite/backends/opencl/cl_kernel/image/nearest_interp_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..b74449d9c8a02551cd74d366849768b4a91a4dce --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/nearest_interp_kernel.cl @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void nearest_interp(__read_only image2d_t input, __write_only image2d_t output, + __private const float scale_h, __private const float scale_w, + __private const int in_dims_h, __private const int out_dims_h, + __private const int in_dims_w, __private const int out_dims_w) { + const int c = get_global_id(0); + const int w = get_global_id(1); + const int nh = get_global_id(2); + int2 output_pos; + output_pos.x = c * out_dims_w + w; + output_pos.y = nh; + int out_n = nh / out_dims_h; + int out_h = nh % out_dims_h; + int2 input_pos; + input_pos.x = c * in_dims_w + w / scale_w; + input_pos.y = out_n * in_dims_h + out_h / scale_h; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + half4 input_data = read_imageh(input, sampler, (int2)(input_pos.x, input_pos.y)); + write_imageh(output, (int2)(output_pos.x, output_pos.y), input_data); +} diff --git a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl index 0ca3b9141daf671737af8d24cd03e59587e33350..775166261d01dc639cd5af8cee49f7e7fb30cb19 100644 --- a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl @@ -12,15 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define MIN_VALUE -FLT_MAX - -__kernel void pool_max( - __private const int in_height, __private const int in_width, - __private const int out_height, __private const int out_width, - __private const int pad_top, __private const int pad_left, - __private const int stride_h, __private const int stride_w, - __private const int ksize_h, __private const int ksize_w, - __read_only image2d_t input, __write_only image2d_t output) { +#include <cl_common.h> + +__kernel void pool_max(__read_only image2d_t input, + __write_only image2d_t output, + __private const int in_height, + __private const int in_width, + __private const int out_height, + __private const int out_width, + __private const int ksize_h, + __private const int ksize_w, + __private const int stride_h, + __private const int stride_w, + __private const int pad_top, + __private const int pad_left) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); const int out_nh = get_global_id(2); @@ -40,25 +45,30 @@ __kernel void pool_max( const int pos_in_x = out_c * in_width; const int pos_in_y = out_n * in_height; - float4 max_value = (float4)(MIN_VALUE); + CL_DTYPE4 max_value = (CL_DTYPE4)(MIN_VALUE); for (int y = start_h; y < end_h; ++y) { for (int x = start_w; x < end_w; ++x) { - float4 tmp = read_imagef(input, sampler, (int2)(pos_in_x + x, pos_in_y + y)); + CL_DTYPE4 tmp = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_x + x, pos_in_y + y)); max_value = max(max_value, tmp); } } const int pos_out_x = mad24(out_c, out_width, out_w); - write_imagef(output, (int2)(pos_out_x, out_nh), max_value); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), max_value); } -__kernel void pool_avg( - __private const int in_height, __private const int in_width, - __private const int out_height, __private const int out_width, - __private const int pad_top, __private const int pad_left, - __private const int stride_h, __private const int stride_w, - __private const int
ksize_h, __private const int ksize_w, - __read_only image2d_t input, __write_only image2d_t output) { +__kernel void pool_avg(__read_only image2d_t input, + __write_only image2d_t output, + __private const int in_height, + __private const int in_width, + __private const int out_height, + __private const int out_width, + __private const int ksize_h, + __private const int ksize_w, + __private const int stride_h, + __private const int stride_w, + __private const int pad_top, + __private const int pad_left) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); const int out_nh = get_global_id(2); @@ -76,15 +86,14 @@ __kernel void pool_avg( const int pos_in_x = out_c * in_width; const int pos_in_y = out_n * in_height; - float4 sum = (float4)(0.0f); - int num = 0; + CL_DTYPE4 sum = (CL_DTYPE4)(0.0f); + for (int y = start_h; y < end_h; ++y) { for (int x = start_w; x < end_w; ++x) { - sum += read_imagef(input, sampler, (int2)(pos_in_x + x, pos_in_y + y)); - num++; + sum += READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(pos_in_x + x, pos_in_y + y)); } } - float4 avg = sum / num; + CL_DTYPE4 avg = sum / (ksize_h * ksize_w); const int pos_out_x = mad24(out_c, out_width, out_w); - write_imagef(output, (int2)(pos_out_x, out_nh), avg); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), avg); } diff --git a/lite/backends/opencl/cl_kernel/image/relu6_kernel.cl b/lite/backends/opencl/cl_kernel/image/relu6_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..7750bd98a29151ba2428bdafd462420393fe7433 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/relu6_kernel.cl @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include <cl_common.h> + +__kernel void relu6(__read_only image2d_t input, + __write_only image2d_t output, + __private const float threshold) { + + const int x = get_global_id(0); + const int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); + in = max((CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), in); + in = min((CL_DTYPE4)(threshold, threshold, threshold, threshold), in); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); +} diff --git a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl index a99ac79d32bcedb48354d2e179ef6c8c1ff7f997..43a27067c2f2c418d314f9bce95bccbbb51a9be0 100644 --- a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl @@ -24,7 +24,7 @@ __kernel void relu(__read_only image2d_t input, CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - CL_DTYPE4 in = read_imagef(input, sampler, (int2)(x, y)); + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); in = max((CL_DTYPE4)(0.0f), in); - write_imagef(output, (int2)(x, y), in); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); }
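The relu6 body above is the usual two-sided clamp spelled as max-then-min; OpenCL's built-in clamp expresses the same computation in one call. An equivalent sketch under the same macros (function name illustrative):

CL_DTYPE4 relu6_ref(CL_DTYPE4 v, float threshold) {
  // clamp(v, lo, hi) == min(max(v, lo), hi), applied per component.
  return clamp(v, (CL_DTYPE4)(0.0f), (CL_DTYPE4)(threshold));
}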
diff --git a/lite/backends/opencl/cl_kernel/image/reshape_kernel.cl b/lite/backends/opencl/cl_kernel/image/reshape_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..314be875d29d2125f9573d33010ee9d33317ea71 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/reshape_kernel.cl @@ -0,0 +1,162 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <cl_common.h> + +__kernel void reshape(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int in_W, + __private const int in_H, + __private const int in_Stride0, + __private const int in_Stride1, + __private const int in_Stride2, + __private const int out_Stride0, + __private const int out_Stride1, + __private const int out_Stride2) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_H; + const int out_h = out_nh % out_H; + const int out_c0 = out_c * 4; + const int out_c1 = out_c * 4 + 1; + const int out_c2 = out_c * 4 + 2; + const int out_c3 = out_c * 4 + 3; + + int count0 = out_n * out_Stride2 + out_c0 * out_Stride1 + out_h * out_Stride0 + out_w; + int count1 = out_n * out_Stride2 + out_c1 * out_Stride1 + out_h * out_Stride0 + out_w; + int count2 = out_n * out_Stride2 + out_c2 * out_Stride1 + out_h * out_Stride0 + out_w; + int count3 = out_n * out_Stride2 + out_c3 * out_Stride1 + out_h * out_Stride0 + out_w; + + int in_n0 = count0 / in_Stride2; + int in_n1 = count1 / in_Stride2; + int in_n2 = count2 / in_Stride2; + int in_n3 = count3 / in_Stride2; + + count0 = count0 % in_Stride2; + count1 = count1 % in_Stride2; + count2 = count2 % in_Stride2; + count3 = count3 % in_Stride2; + + int in_c0 = count0 / in_Stride1; + int in_c1 = count1 / in_Stride1; + int in_c2 = count2 / in_Stride1; + int in_c3 = count3 / in_Stride1; + + int in_h0 = (count0 % in_Stride1) / in_Stride0; + int in_h1 = (count1 % in_Stride1) / in_Stride0; + int in_h2 = (count2 % in_Stride1) / in_Stride0; + int in_h3 = (count3 % in_Stride1) / in_Stride0; + + int in_w0 = (count0 % in_Stride1) % in_Stride0; + int in_w1 = (count1 % in_Stride1) % in_Stride0; + int in_w2 = (count2 % in_Stride1) % in_Stride0; + int in_w3 = (count3 % in_Stride1) % in_Stride0; + + int2 input_pos0; + int2 input_pos1; + int2 input_pos2; + int2 input_pos3; + + input_pos0.x = (in_c0 / 4) * in_W + in_w0; + input_pos0.y = in_n0 * in_H + in_h0; + + input_pos1.x = (in_c1 / 4) * in_W + in_w1; + input_pos1.y = in_n1 * in_H + in_h1; + + input_pos2.x = (in_c2 / 4) * in_W + in_w2; + input_pos2.y = in_n2 * in_H + in_h2; + + input_pos3.x = (in_c3 / 4) * in_W + in_w3; + input_pos3.y = in_n3 * in_H + in_h3; + + int2 output_pos; + output_pos.x = out_c * out_W + out_w; + output_pos.y = out_nh; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 input0; + CL_DTYPE4 input1; + CL_DTYPE4 input2; + CL_DTYPE4 input3; + CL_DTYPE4 output; + + input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos0); + if (in_c0 % 4 == 0) { + output.x = input0.x; + } else if (in_c0 % 4 == 1) { + output.x = input0.y; + } else if (in_c0 % 4 == 2) { + output.x = input0.z; + } else { + output.x = input0.w; + } + + if (out_C - out_c * 4 >= 2) { + input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos1); + if (in_c1 % 4 == 0) { + output.y = input1.x; + } else if (in_c1 % 4 == 1) { + output.y = input1.y; + } else if (in_c1 % 4 == 2) { + output.y = input1.z; + } else { + output.y = input1.w; + } + } else { + output.y = 0.0f; + } + + if (out_C - out_c * 4 >= 3) { + input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos2); + if (in_c2 % 4 == 0) { + output.z = input2.x; + } else if (in_c2 % 4 == 1) { + output.z = input2.y; + } else if (in_c2 % 4 == 2) { + output.z = input2.z; + } else { + output.z = input2.w; + } + } else { + output.z = 0.0f; + } + + if (out_C - out_c * 4 >= 4) { + input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos3); + if (in_c3 % 4 == 0) { + output.w = input3.x; + } else if (in_c3 % 4 == 1) { + output.w = input3.y; + } else if (in_c3 % 4 == 2) { + output.w = input3.z; + } else { + output.w = input3.w; + } + } else { + output.w = 0.0f; + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +}
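The reshape kernel above is plain NCHW linearization and delinearization done once per channel lane: each output element's flat offset n*Stride2 + c*Stride1 + h*Stride0 + w is decomposed again with the input strides to locate the source texel. A scalar sketch of that round trip, assuming the host passes Stride0 = W, Stride1 = H*W and Stride2 = C*H*W (names illustrative):

int flat_offset(int n, int c, int h, int w, int s0, int s1, int s2) {
  return n * s2 + c * s1 + h * s0 + w;
}

void unflatten(int count, int s0, int s1, int s2,
               int *n, int *c, int *h, int *w) {
  *n = count / s2;
  count = count % s2;
  *c = count / s1;
  *h = (count % s1) / s0;
  *w = (count % s1) % s0;
}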
diff --git a/lite/backends/opencl/cl_kernel/image/scale_kernel.cl b/lite/backends/opencl/cl_kernel/image/scale_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..739ff1338582b65d87dbd9c92f1ea86e0c49f0ff --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/scale_kernel.cl @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <cl_common.h> + +__kernel void scale(__read_only image2d_t input, + __write_only image2d_t output, + __private float scale, + __private float bias) { + + const int x = get_global_id(0); // image_width + const int y = get_global_id(1); // image_height + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); + in = convert_float(scale) * in + convert_float(bias); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); +}
diff --git a/lite/backends/opencl/cl_kernel/image/sigmoid_kernel.cl b/lite/backends/opencl/cl_kernel/image/sigmoid_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..d2cb8fa36e21167979172fba634a7862c932b74c --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/sigmoid_kernel.cl @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <cl_common.h> + +__kernel void sigmoid(__read_only image2d_t input, + __write_only image2d_t output) { + + const int x = get_global_id(0); // image_width + const int y = get_global_id(1); // image_height + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); + CL_DTYPE4 out = 1 / (1 + exp(-in)); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out); +} diff --git a/lite/backends/opencl/cl_runtime.cc b/lite/backends/opencl/cl_runtime.cc index c2504ab611e93399c70169f3f123d4a0514c07ad..0c7b2f8575a88082f6d79a5392c4468715a701b9 100644 --- a/lite/backends/opencl/cl_runtime.cc +++ b/lite/backends/opencl/cl_runtime.cc @@ -103,6 +103,7 @@ std::unique_ptr CLRuntime::CreateEvent( bool CLRuntime::BuildProgram(cl::Program* program, const std::string& options) { std::string build_option = options + " -cl-fast-relaxed-math -I " + CLRuntime::Global()->cl_path() + "/cl_kernel"; + VLOG(4) << "OpenCL build_option: " << build_option; status_ = program->build({*device_}, build_option.c_str()); CL_CHECK_ERROR(status_); diff --git a/lite/backends/opencl/target_wrapper.cc b/lite/backends/opencl/target_wrapper.cc index 575f87d0f8d0192345c6ab111db46715a809a976..310567baa539697f6a67b59f6c0e5f29ce46a80e 100644 --- a/lite/backends/opencl/target_wrapper.cc +++ b/lite/backends/opencl/target_wrapper.cc @@ -24,6 +24,8 @@ static cl_channel_type GetCLChannelType(const PrecisionType type) { switch (type) { case PRECISION(kFloat): return CL_FLOAT; + case PRECISION(kFP16): + return CL_HALF_FLOAT; case PRECISION(kInt32): return CL_SIGNED_INT32; case PRECISION(kInt8): @@ -58,17 +60,18 @@ void TargetWrapperCL::Free(void *ptr) { template <> void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, - const size_t cl_image2d_height) { + const size_t cl_image2d_height, + void *host_ptr) { cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFloat))); cl_int status; cl::Image2D *cl_image = new cl::Image2D(CLRuntime::Global()->context(), - CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0), img_format, cl_image2d_width, cl_image2d_height, 0, - nullptr, + host_ptr, &status); if (status != CL_SUCCESS) { delete cl_image; @@ -78,19 +81,20 @@ void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, return cl_image; } -template <> -void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, - const size_t cl_image2d_height) { - cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt8))); +template <> // use int16_t to represent a half float +void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, + const size_t cl_image2d_height, + void *host_ptr) { + cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFP16))); + cl_int status; + cl::Image2D *cl_image = new cl::Image2D(CLRuntime::Global()->context(), - CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + CL_MEM_READ_WRITE | (host_ptr ?
CL_MEM_COPY_HOST_PTR : 0), img_format, cl_image2d_width, cl_image2d_height, 0, - nullptr, + host_ptr, &status); if (status != CL_SUCCESS) { delete cl_image; @@ -102,17 +106,18 @@ void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, template <> void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, - const size_t cl_image2d_height) { + const size_t cl_image2d_height, + void *host_ptr) { cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt32))); cl_int status; cl::Image2D *cl_image = new cl::Image2D(CLRuntime::Global()->context(), - CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0), img_format, cl_image2d_width, cl_image2d_height, 0, - nullptr, + host_ptr, &status); if (status != CL_SUCCESS) { delete cl_image; diff --git a/lite/backends/opencl/target_wrapper.h b/lite/backends/opencl/target_wrapper.h index 7753448052e17ac739f730c9fabcaf9533e0045e..c5ff9e900a70fd96ccb461c74fb61e33815a5e81 100644 --- a/lite/backends/opencl/target_wrapper.h +++ b/lite/backends/opencl/target_wrapper.h @@ -48,7 +48,8 @@ class TargetWrapper { template static void* MallocImage(const size_t cl_image2d_width, - const size_t cl_image2d_height); + const size_t cl_image2d_height, + void* host_ptr = nullptr); static void FreeImage(void* image); static void* Map(void* buffer, size_t offset, size_t size); diff --git a/lite/backends/x86/cpu_info.cc b/lite/backends/x86/cpu_info.cc index c2759d6191aaa7ba277ff2a935ea6fdda8383e1e..aa097f947a0289b4a44417160fbe5d6e6db48020 100644 --- a/lite/backends/x86/cpu_info.cc +++ b/lite/backends/x86/cpu_info.cc @@ -32,26 +32,37 @@ #include #include -DEFINE_double(fraction_of_cpu_memory_to_use, - 1, - "Default use 100% of CPU memory for PaddlePaddle," - "reserve the rest for page tables, etc"); -DEFINE_uint64(initial_cpu_memory_in_mb, - 500ul, - "Initial CPU memory for PaddlePaddle, in MD unit."); - -DEFINE_double( - fraction_of_cuda_pinned_memory_to_use, - 0.5, - "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle," - "reserve the rest for page tables, etc"); +#include "lite/utils/env.h" + +// DEFINE_double(fraction_of_cpu_memory_to_use, +// 1, +// "Default use 100% of CPU memory for PaddlePaddle," +// "reserve the rest for page tables, etc"); +double fraction_of_cpu_memory_to_use = + paddle::lite::GetDoubleFromEnv("fraction_of_cpu_memory_to_use", 1); + +// DEFINE_uint64(initial_cpu_memory_in_mb, +// 500ul, +// "Initial CPU memory for PaddlePaddle, in MD unit."); +uint64_t initial_cpu_memory_in_mb = + paddle::lite::GetUInt64FromEnv("initial_cpu_memory_in_mb", 500ul); + +// DEFINE_double( +// fraction_of_cuda_pinned_memory_to_use, +// 0.5, +// "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle," +// "reserve the rest for page tables, etc"); +double fraction_of_cuda_pinned_memory_to_use = paddle::lite::GetDoubleFromEnv( + "fraction_of_cuda_pinned_memory_to_use", 0.5); // If use_pinned_memory is true, CPUAllocator calls mlock, which // returns pinned and locked memory as staging areas for data exchange // between host and device. Allocates too much would reduce the amount // of memory available to the system for paging. So, by default, we // should set false to use_pinned_memory. 
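The pattern in this hunk (and in the dynamic_loader.cc one below) swaps compiled-in gflags definitions for environment-variable lookups, so the settings remain tunable without linking gflags into the lite build. A minimal plain-C sketch of what such a boolean lookup amounts to; the exact parsing rules of paddle::lite::GetBoolFromEnv may differ, and the helper name is illustrative:

#include <stdlib.h>
#include <string.h>

static int get_bool_from_env(const char *name, int default_value) {
  const char *v = getenv(name); /* unset or empty -> keep the default */
  if (v == NULL || *v == '\0') return default_value;
  return strcmp(v, "0") != 0 && strcmp(v, "false") != 0;
}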
-DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory."); +// DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory."); +bool use_pinned_memory = + paddle::lite::GetBoolFromEnv("use_pinned_memory", true); namespace paddle { namespace lite { @@ -81,7 +92,7 @@ size_t CpuTotalPhysicalMemory() { size_t CpuMaxAllocSize() { // For distributed systems, it requires configuring and limiting // the fraction of memory to use. - return FLAGS_fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory(); + return fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory(); } size_t CpuMinChunkSize() { @@ -92,15 +103,14 @@ size_t CpuMinChunkSize() { size_t CpuMaxChunkSize() { // Allow to allocate the maximum chunk size is roughly 3% of CPU memory, // or the initial_cpu_memory_in_mb. - return std::min( - static_cast(CpuMaxAllocSize() / 32), - static_cast(FLAGS_initial_cpu_memory_in_mb * 1 << 20)); + return std::min(static_cast(CpuMaxAllocSize() / 32), + static_cast(initial_cpu_memory_in_mb * 1 << 20)); } size_t CUDAPinnedMaxAllocSize() { // For distributed systems, it requires configuring and limiting // the fraction of memory to use. - return FLAGS_fraction_of_cuda_pinned_memory_to_use * CpuTotalPhysicalMemory(); + return fraction_of_cuda_pinned_memory_to_use * CpuTotalPhysicalMemory(); } size_t CUDAPinnedMinChunkSize() { diff --git a/lite/backends/x86/dynamic_loader.cc b/lite/backends/x86/dynamic_loader.cc index 75bb528f38664fc1061653e1036b73eed74daae9..a05a57e93b23008e49683764b5ed669d5c425e5b 100644 --- a/lite/backends/x86/dynamic_loader.cc +++ b/lite/backends/x86/dynamic_loader.cc @@ -22,36 +22,46 @@ limitations under the License. */ #include "lite/backends/x86/cupti_lib_path.h" #include "lite/backends/x86/port.h" #include "lite/backends/x86/warpctc_lib_path.h" +#include "lite/utils/env.h" #include "lite/utils/paddle_enforce.h" -DEFINE_string(cudnn_dir, - "", - "Specify path for loading libcudnn.so. For instance, " - "/usr/local/cudnn/lib. If empty [default], dlopen " - "will search cudnn from LD_LIBRARY_PATH"); +// DEFINE_string(cudnn_dir, +// "", +// "Specify path for loading libcudnn.so. For instance, " +// "/usr/local/cudnn/lib. If empty [default], dlopen " +// "will search cudnn from LD_LIBRARY_PATH"); +std::string cudnn_dir = paddle::lite::GetStringFromEnv("cudnn_dir"); // NOLINT -DEFINE_string(cuda_dir, - "", - "Specify path for loading cuda library, such as libcublas, " - "libcurand. For instance, /usr/local/cuda/lib64. If default, " - "dlopen will search cuda from LD_LIBRARY_PATH"); +// DEFINE_string(cuda_dir, +// "", +// "Specify path for loading cuda library, such as libcublas, " +// "libcurand. For instance, /usr/local/cuda/lib64. If default, " +// "dlopen will search cuda from LD_LIBRARY_PATH"); +std::string cuda_dir = paddle::lite::GetStringFromEnv("cuda_dir"); // NOLINT -DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); +// DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); +std::string f_warpctc_dir = // NOLINT + paddle::lite::GetStringFromEnv("warpctc_dir"); // NOLINT -DEFINE_string(nccl_dir, - "", - "Specify path for loading nccl library, such as libcublas, " - "libcurand. For instance, /usr/local/cuda/lib64. If default, " - "dlopen will search cuda from LD_LIBRARY_PATH"); +// DEFINE_string(nccl_dir, +// "", +// "Specify path for loading nccl library, such as libcublas, " +// "libcurand. For instance, /usr/local/cuda/lib64. 
If default, " +// "dlopen will search cuda from LD_LIBRARY_PATH"); +std::string nccl_dir = paddle::lite::GetStringFromEnv("nccl_dir"); // NOLINT -DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so."); +// DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so."); +std::string cupti_dir = paddle::lite::GetStringFromEnv("cupti_dir"); // NOLINT -DEFINE_string( - tensorrt_dir, - "", - "Specify path for loading tensorrt library, such as libnvinfer.so."); +// DEFINE_string( +// tensorrt_dir, +// "", +// "Specify path for loading tensorrt library, such as libnvinfer.so."); +std::string tensorrt_dir = // NOLINT + paddle::lite::GetStringFromEnv("tensorrt_dir"); // NOLINT -DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); +// DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); +std::string mklml_dir = paddle::lite::GetStringFromEnv("mklml_dir"); // NOLINT namespace paddle { namespace lite { @@ -180,28 +190,28 @@ auto error_msg = void* GetCublasDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib"); + return GetDsoHandleFromSearchPath(cuda_dir, "libcublas.dylib"); #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cublas_lib); + return GetDsoHandleFromSearchPath(cuda_dir, win_cublas_lib); #else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so"); + return GetDsoHandleFromSearchPath(cuda_dir, "libcublas.so"); #endif } void* GetCUDNNDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", false); + return GetDsoHandleFromSearchPath(cudnn_dir, "libcudnn.dylib", false); #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, win_cudnn_lib); + return GetDsoHandleFromSearchPath(cudnn_dir, win_cudnn_lib); #else - return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", false); + return GetDsoHandleFromSearchPath(cudnn_dir, "libcudnn.so", false); #endif } void* GetCUPTIDsoHandle() { std::string cupti_path = cupti_lib_path; - if (!FLAGS_cupti_dir.empty()) { - cupti_path = FLAGS_cupti_dir; + if (!cupti_dir.empty()) { + cupti_path = cupti_dir; } #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(cupti_path, "libcupti.dylib", false); @@ -212,18 +222,18 @@ void* GetCUPTIDsoHandle() { void* GetCurandDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib"); + return GetDsoHandleFromSearchPath(cuda_dir, "libcurand.dylib"); #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_curand_lib); + return GetDsoHandleFromSearchPath(cuda_dir, win_curand_lib); #else - return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so"); + return GetDsoHandleFromSearchPath(cuda_dir, "libcurand.so"); #endif } void* GetWarpCTCDsoHandle() { std::string warpctc_dir = warpctc_lib_path; - if (!FLAGS_warpctc_dir.empty()) { - warpctc_dir = FLAGS_warpctc_dir; + if (!f_warpctc_dir.empty()) { + warpctc_dir = f_warpctc_dir; } #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(warpctc_dir, "libwarpctc.dylib"); @@ -236,27 +246,27 @@ void* GetWarpCTCDsoHandle() { void* GetNCCLDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib"); + return 
GetDsoHandleFromSearchPath(nccl_dir, "libnccl.dylib"); #else - return GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so"); + return GetDsoHandleFromSearchPath(nccl_dir, "libnccl.so"); #endif } void* GetTensorRtDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib"); + return GetDsoHandleFromSearchPath(tensorrt_dir, "libnvinfer.dylib"); #else - return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so"); + return GetDsoHandleFromSearchPath(tensorrt_dir, "libnvinfer.so"); #endif } void* GetMKLMLDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.dylib"); + return GetDsoHandleFromSearchPath(mklml_dir, "libmklml_intel.dylib"); #elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "mklml.dll"); + return GetDsoHandleFromSearchPath(mklml_dir, "mklml.dll"); #else - return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.so"); + return GetDsoHandleFromSearchPath(mklml_dir, "libmklml_intel.so"); #endif } diff --git a/lite/backends/x86/jit/README.en.md b/lite/backends/x86/jit/README.en.md index cd2aa5c242dba1a9be669a536cd9b614bf890e48..dc9eb4cf239155ba15a855c98e5515adb717d2d5 100644 --- a/lite/backends/x86/jit/README.en.md +++ b/lite/backends/x86/jit/README.en.md @@ -89,7 +89,7 @@ All kernels are included in `lite/backends/x86/jit/kernels.h`, which is automatically generated. 3. Add reference function of `your_key`. Note: - this should be run on CPU and must not depend on any third-party library. - - Add `USE_JITKERNEL_REFER(your_key)` in `refer/CmakeLists.txt` to make sure this code can be used. + - Add `USE_JITKERNEL_REFER_LITE(your_key)` in `refer/CmakeLists.txt` to make sure this code can be used. 4. Add unit test in `test.cc`, and verify at least `float` and `double`. Test more data types for some special functions if necessary, for example `int8`. 5. Add functions in `benchmark.cc` to test all functions of the same `KernelType`. Make sure `GetDefaultBestFunc` always gets the best one. diff --git a/lite/backends/x86/jit/README.md b/lite/backends/x86/jit/README.md index 6998c5d867b079dfef69a71ca56e6f3fc30363d4..bc0e27234d05c82c9b0dcc431343d7db1a0f4067 100644 --- a/lite/backends/x86/jit/README.md +++ b/lite/backends/x86/jit/README.md @@ -79,7 +79,7 @@ PaddlePaddle/Paddle/paddle/fluid/ # 如何添加新的算子 1. 在`KernelType` 中添加 `your_key` 。 -2. 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`来使用该kernel。 +2. 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER_LITE(your_key)`来使用该kernel。 3. (optional) 实现更多的算法在`more`目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。 4. (optional) 实现基于Xbyak的生成code,在`gen`目下。 jitcode需要实现自己的`JitCodeCreator`,并注册在与refer相同的`KernelType`上。 5.
添加新的`KernelTuple`,需要与`KernelType`一一对应,是所有类型的一个打包,包括数据类型,属性的类型,以及返回的函数类型。可以参考`SeqPoolTuple`,新加的Attr类型需要特例化`JitCodeKey`方法。 diff --git a/lite/backends/x86/jit/gen/CMakeLists.txt b/lite/backends/x86/jit/gen/CMakeLists.txt index 99244ea9bd919a018732b75d1ab811e8bf338516..62500775282d1c3d960f0fa9b00d3d4a2aef9390 100644 --- a/lite/backends/x86/jit/gen/CMakeLists.txt +++ b/lite/backends/x86/jit/gen/CMakeLists.txt @@ -4,33 +4,33 @@ file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE) -function(USE_JITKERNEL_GEN TARGET) - file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n") +function(USE_JITKERNEL_GEN_LITE TARGET) + file(APPEND ${jit_file} "USE_JITKERNEL_GEN_LITE(${TARGET});\n") endfunction() # use gen jitcode kernel by name -USE_JITKERNEL_GEN(kMatMul) -USE_JITKERNEL_GEN(kVMul) -USE_JITKERNEL_GEN(kVAdd) -USE_JITKERNEL_GEN(kVSub) -USE_JITKERNEL_GEN(kVAddRelu) -USE_JITKERNEL_GEN(kVScal) -USE_JITKERNEL_GEN(kVAddBias) -USE_JITKERNEL_GEN(kVRelu) -USE_JITKERNEL_GEN(kVSquare) -USE_JITKERNEL_GEN(kVIdentity) -USE_JITKERNEL_GEN(kVExp) -USE_JITKERNEL_GEN(kVSigmoid) -USE_JITKERNEL_GEN(kVTanh) -USE_JITKERNEL_GEN(kLSTMCtHt) -USE_JITKERNEL_GEN(kLSTMC1H1) -USE_JITKERNEL_GEN(kGRUH1) -USE_JITKERNEL_GEN(kGRUHtPart1) -USE_JITKERNEL_GEN(kGRUHtPart2) -USE_JITKERNEL_GEN(kNCHW16CMulNC) -USE_JITKERNEL_GEN(kSeqPool) -USE_JITKERNEL_GEN(kHMax) -USE_JITKERNEL_GEN(kHSum) -USE_JITKERNEL_GEN(kEmbSeqPool) -USE_JITKERNEL_GEN(kSgd) -USE_JITKERNEL_GEN(kVBroadcast) +USE_JITKERNEL_GEN_LITE(kMatMul) +USE_JITKERNEL_GEN_LITE(kVMul) +USE_JITKERNEL_GEN_LITE(kVAdd) +USE_JITKERNEL_GEN_LITE(kVSub) +USE_JITKERNEL_GEN_LITE(kVAddRelu) +USE_JITKERNEL_GEN_LITE(kVScal) +USE_JITKERNEL_GEN_LITE(kVAddBias) +USE_JITKERNEL_GEN_LITE(kVRelu) +USE_JITKERNEL_GEN_LITE(kVSquare) +USE_JITKERNEL_GEN_LITE(kVIdentity) +USE_JITKERNEL_GEN_LITE(kVExp) +USE_JITKERNEL_GEN_LITE(kVSigmoid) +USE_JITKERNEL_GEN_LITE(kVTanh) +USE_JITKERNEL_GEN_LITE(kLSTMCtHt) +USE_JITKERNEL_GEN_LITE(kLSTMC1H1) +USE_JITKERNEL_GEN_LITE(kGRUH1) +USE_JITKERNEL_GEN_LITE(kGRUHtPart1) +USE_JITKERNEL_GEN_LITE(kGRUHtPart2) +USE_JITKERNEL_GEN_LITE(kNCHW16CMulNC) +USE_JITKERNEL_GEN_LITE(kSeqPool) +USE_JITKERNEL_GEN_LITE(kHMax) +USE_JITKERNEL_GEN_LITE(kHSum) +USE_JITKERNEL_GEN_LITE(kEmbSeqPool) +USE_JITKERNEL_GEN_LITE(kSgd) +USE_JITKERNEL_GEN_LITE(kVBroadcast) diff --git a/lite/backends/x86/jit/gen/act.cc b/lite/backends/x86/jit/gen/act.cc index f1f261c199d8d25997b1ce235aa99356834e43a8..45f4f7ddcce8e8864821712698c4496cf40b618c 100644 --- a/lite/backends/x86/jit/gen/act.cc +++ b/lite/backends/x86/jit/gen/act.cc @@ -156,9 +156,9 @@ size_t VTanhCreator::CodeSize(const int& d) const { namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator); -REGISTER_JITKERNEL_GEN(kVSquare, gen::VSquareCreator); -REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator); -REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator); -REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator); -REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator); +REGISTER_JITKERNEL_GEN_LITE(kVRelu, gen::VReluCreator); +REGISTER_JITKERNEL_GEN_LITE(kVSquare, gen::VSquareCreator); +REGISTER_JITKERNEL_GEN_LITE(kVIdentity, gen::VIdentityCreator); +REGISTER_JITKERNEL_GEN_LITE(kVExp, gen::VExpCreator); +REGISTER_JITKERNEL_GEN_LITE(kVSigmoid, gen::VSigmoidCreator); +REGISTER_JITKERNEL_GEN_LITE(kVTanh, gen::VTanhCreator); diff --git 
a/lite/backends/x86/jit/gen/blas.cc b/lite/backends/x86/jit/gen/blas.cc index 0bddea6ace7fd338d14da918516223bb17bafdbd..37183e66404dfae139a2bcd25c2855df119f939d 100644 --- a/lite/backends/x86/jit/gen/blas.cc +++ b/lite/backends/x86/jit/gen/blas.cc @@ -181,10 +181,10 @@ DECLARE_BLAS_CREATOR(VAddBias); namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator); -REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator); -REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator); -REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); -REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); -REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); -REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator); +REGISTER_JITKERNEL_GEN_LITE(kVMul, gen::VMulCreator); +REGISTER_JITKERNEL_GEN_LITE(kVAdd, gen::VAddCreator); +REGISTER_JITKERNEL_GEN_LITE(kVSub, gen::VSubCreator); +REGISTER_JITKERNEL_GEN_LITE(kVAddRelu, gen::VAddReluCreator); +REGISTER_JITKERNEL_GEN_LITE(kVScal, gen::VScalCreator); +REGISTER_JITKERNEL_GEN_LITE(kVAddBias, gen::VAddBiasCreator); +REGISTER_JITKERNEL_GEN_LITE(kNCHW16CMulNC, gen::NCHW16CMulNCCreator); diff --git a/lite/backends/x86/jit/gen/embseqpool.cc b/lite/backends/x86/jit/gen/embseqpool.cc index 2ff6894383f95699e4209215b0df3a84507a06b4..7e697014ed241a75693b783127633b255964f80b 100644 --- a/lite/backends/x86/jit/gen/embseqpool.cc +++ b/lite/backends/x86/jit/gen/embseqpool.cc @@ -145,4 +145,4 @@ class EmbSeqPoolCreator : public JitCodeCreator { namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kEmbSeqPool, gen::EmbSeqPoolCreator); +REGISTER_JITKERNEL_GEN_LITE(kEmbSeqPool, gen::EmbSeqPoolCreator); diff --git a/lite/backends/x86/jit/gen/gru.cc b/lite/backends/x86/jit/gen/gru.cc index c5737faf134287697ef49b88f10c2590da4cc07d..4c2c57413e30589de96385c34e09733458f66b7b 100644 --- a/lite/backends/x86/jit/gen/gru.cc +++ b/lite/backends/x86/jit/gen/gru.cc @@ -111,6 +111,6 @@ DECLARE_GRU_CREATOR(GRUHtPart2); namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator); -REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator); -REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator); +REGISTER_JITKERNEL_GEN_LITE(kGRUH1, gen::GRUH1Creator); +REGISTER_JITKERNEL_GEN_LITE(kGRUHtPart1, gen::GRUHtPart1Creator); +REGISTER_JITKERNEL_GEN_LITE(kGRUHtPart2, gen::GRUHtPart2Creator); diff --git a/lite/backends/x86/jit/gen/hopv.cc b/lite/backends/x86/jit/gen/hopv.cc index 4304dc48c5a084a747227bd4d4aedb1cec1775cd..0fdd63a7405647860416d43a86a7a7abe9fad760 100644 --- a/lite/backends/x86/jit/gen/hopv.cc +++ b/lite/backends/x86/jit/gen/hopv.cc @@ -99,5 +99,5 @@ DECLARE_HOP_CREATOR(HSum); namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator); -REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator); +REGISTER_JITKERNEL_GEN_LITE(kHMax, gen::HMaxCreator); +REGISTER_JITKERNEL_GEN_LITE(kHSum, gen::HSumCreator); diff --git a/lite/backends/x86/jit/gen/lstm.cc b/lite/backends/x86/jit/gen/lstm.cc index 44e58d0b75612238115d5771082d28c30cad55a2..e4417355202c6370563eadd80e5cb3da6af8cdc6 100644 --- a/lite/backends/x86/jit/gen/lstm.cc +++ b/lite/backends/x86/jit/gen/lstm.cc @@ -138,5 +138,5 @@ DECLARE_LSTM_CREATOR(LSTMC1H1); namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator); -REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator); +REGISTER_JITKERNEL_GEN_LITE(kLSTMCtHt, gen::LSTMCtHtCreator); +REGISTER_JITKERNEL_GEN_LITE(kLSTMC1H1, gen::LSTMC1H1Creator); diff --git 
a/lite/backends/x86/jit/gen/matmul.cc b/lite/backends/x86/jit/gen/matmul.cc index 2c75f6dd5dc4bbf12513d10ef0a4e02e709135fd..010c80fac4842e74c9b8272db472ddf6cf954771 100644 --- a/lite/backends/x86/jit/gen/matmul.cc +++ b/lite/backends/x86/jit/gen/matmul.cc @@ -130,4 +130,4 @@ class MatMulCreator : public JitCodeCreator { namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator); +REGISTER_JITKERNEL_GEN_LITE(kMatMul, gen::MatMulCreator); diff --git a/lite/backends/x86/jit/gen/seqpool.cc b/lite/backends/x86/jit/gen/seqpool.cc index e0cf5e5a5a7646f09666f6ccb35b18610c845317..4c80737aac4bc9cd09f4ff222c8fad8c441887ec 100644 --- a/lite/backends/x86/jit/gen/seqpool.cc +++ b/lite/backends/x86/jit/gen/seqpool.cc @@ -82,4 +82,4 @@ class SeqPoolCreator : public JitCodeCreator { namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator); +REGISTER_JITKERNEL_GEN_LITE(kSeqPool, gen::SeqPoolCreator); diff --git a/lite/backends/x86/jit/gen/sgd.cc b/lite/backends/x86/jit/gen/sgd.cc index 10659f50844d73c14403f9e7a35d800364be1e7b..44e083366132c675b339b2da4bbb3b7c1c6b7569 100644 --- a/lite/backends/x86/jit/gen/sgd.cc +++ b/lite/backends/x86/jit/gen/sgd.cc @@ -127,4 +127,4 @@ class SgdCreator : public JitCodeCreator { namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kSgd, gen::SgdCreator); +REGISTER_JITKERNEL_GEN_LITE(kSgd, gen::SgdCreator); diff --git a/lite/backends/x86/jit/gen/vbroadcast.cc b/lite/backends/x86/jit/gen/vbroadcast.cc index 9e02dca8c40975fb45feed1d818bbe6d3e65db19..fb1e71f7b0b1e6f68a331d264682e80fbab7c219 100644 --- a/lite/backends/x86/jit/gen/vbroadcast.cc +++ b/lite/backends/x86/jit/gen/vbroadcast.cc @@ -88,4 +88,4 @@ class VBroadcastCreator : public JitCodeCreator { namespace gen = paddle::lite::jit::gen; -REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator); +REGISTER_JITKERNEL_GEN_LITE(kVBroadcast, gen::VBroadcastCreator); diff --git a/lite/backends/x86/jit/gen_base.cc b/lite/backends/x86/jit/gen_base.cc index 38250d533dd8c94afc87b5f9113ea165d6b7e9ed..7d051aa6f5802844753b71fd43400e20b7f5965b 100644 --- a/lite/backends/x86/jit/gen_base.cc +++ b/lite/backends/x86/jit/gen_base.cc @@ -21,13 +21,15 @@ // posix_memalign #include "lite/backends/x86/cpu_info.h" #include "lite/backends/x86/jit/macro.h" +#include "lite/utils/env.h" #include "lite/utils/paddle_enforce.h" #ifndef _WIN32 #define posix_memalign_free free #endif -DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); +// DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); +bool dump_jitcode = paddle::lite::GetBoolFromEnv("dump_jitcode"); namespace paddle { namespace lite { diff --git a/lite/backends/x86/jit/gen_base.h b/lite/backends/x86/jit/gen_base.h index b5f942615aa001a119273b52c70116ae66e66126..4af93c2447d64e52676a60e33c01c63ba7221910 100644 --- a/lite/backends/x86/jit/gen_base.h +++ b/lite/backends/x86/jit/gen_base.h @@ -20,7 +20,8 @@ #include #include "lite/backends/x86/jit/kernel_base.h" -DECLARE_bool(dump_jitcode); +// DECLARE_bool(dump_jitcode); +extern bool dump_jitcode; namespace paddle { namespace lite { @@ -36,7 +37,7 @@ class GenBase : public Kernel { template Func getCode() const { const unsigned char* code = this->getCodeInternal(); - if (FLAGS_dump_jitcode) { + if (dump_jitcode) { this->dumpCode(code); } // Note: failed to cast with reinterpret_cast on Mac clang, diff --git a/lite/backends/x86/jit/more/CMakeLists.txt b/lite/backends/x86/jit/more/CMakeLists.txt index 
2ddbbcd16a3ffef560581592e3a009c61844d4d5..5641466d8a86e4be7b88d7eaf977e5a58d18f085 100644 --- a/lite/backends/x86/jit/more/CMakeLists.txt +++ b/lite/backends/x86/jit/more/CMakeLists.txt @@ -1,6 +1,6 @@ -function(USE_JITKERNEL_MORE TARGET TYPE) - file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n") +function(USE_JITKERNEL_MORE_LITE TARGET TYPE) + file(APPEND ${jit_file} "USE_JITKERNEL_MORE_LITE(${TARGET} ${TYPE});\n") endfunction() # enable it latter diff --git a/lite/backends/x86/jit/more/intrinsic/CMakeLists.txt b/lite/backends/x86/jit/more/intrinsic/CMakeLists.txt index 468937a4f6b27ae525bfd0d8e99cc891eedbc353..80dabc72fbe2db46359cd69760eb5a02cea615af 100644 --- a/lite/backends/x86/jit/more/intrinsic/CMakeLists.txt +++ b/lite/backends/x86/jit/more/intrinsic/CMakeLists.txt @@ -5,5 +5,5 @@ cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE) # use mkl kernels by name and type -USE_JITKERNEL_MORE(kCRFDecoding, intrinsic) -USE_JITKERNEL_MORE(kLayerNorm, intrinsic) +USE_JITKERNEL_MORE_LITE(kCRFDecoding, intrinsic) +USE_JITKERNEL_MORE_LITE(kLayerNorm, intrinsic) diff --git a/lite/backends/x86/jit/more/mix/CMakeLists.txt b/lite/backends/x86/jit/more/mix/CMakeLists.txt index dd039d29152961210958470a48f086a133ab640c..5e0238f26f1ebbd298dba0957bdc93e16671505f 100644 --- a/lite/backends/x86/jit/more/mix/CMakeLists.txt +++ b/lite/backends/x86/jit/more/mix/CMakeLists.txt @@ -5,11 +5,11 @@ cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE) -USE_JITKERNEL_MORE(kVSigmoid, mix) -USE_JITKERNEL_MORE(kVTanh, mix) -USE_JITKERNEL_MORE(kLSTMCtHt, mix) -USE_JITKERNEL_MORE(kLSTMC1H1, mix) -USE_JITKERNEL_MORE(kGRUH1, mix) -USE_JITKERNEL_MORE(kGRUHtPart1, mix) -USE_JITKERNEL_MORE(kGRUHtPart2, mix) -USE_JITKERNEL_MORE(kSoftmax, mix) +USE_JITKERNEL_MORE_LITE(kVSigmoid, mix) +USE_JITKERNEL_MORE_LITE(kVTanh, mix) +USE_JITKERNEL_MORE_LITE(kLSTMCtHt, mix) +USE_JITKERNEL_MORE_LITE(kLSTMC1H1, mix) +USE_JITKERNEL_MORE_LITE(kGRUH1, mix) +USE_JITKERNEL_MORE_LITE(kGRUHtPart1, mix) +USE_JITKERNEL_MORE_LITE(kGRUHtPart2, mix) +USE_JITKERNEL_MORE_LITE(kSoftmax, mix) diff --git a/lite/backends/x86/jit/more/mkl/CMakeLists.txt b/lite/backends/x86/jit/more/mkl/CMakeLists.txt index 56f1a62ad4e06807dace2a81156d92f6b02a14df..3557f531a561caace51225ad23e2d547ad48d08c 100644 --- a/lite/backends/x86/jit/more/mkl/CMakeLists.txt +++ b/lite/backends/x86/jit/more/mkl/CMakeLists.txt @@ -3,18 +3,18 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE) # use mkl kernels by name and type -USE_JITKERNEL_MORE(kMatMul, mkl) -USE_JITKERNEL_MORE(kVMul, mkl) -USE_JITKERNEL_MORE(kVAdd, mkl) -USE_JITKERNEL_MORE(kVScal, mkl) -USE_JITKERNEL_MORE(kStrideScal, mkl) -USE_JITKERNEL_MORE(kVExp, mkl) -USE_JITKERNEL_MORE(kVSquare, mkl) -USE_JITKERNEL_MORE(kVCopy, mkl) -USE_JITKERNEL_MORE(kVSigmoid, mkl) -USE_JITKERNEL_MORE(kVTanh, mkl) -USE_JITKERNEL_MORE(kSeqPool, mkl) -USE_JITKERNEL_MORE(kSoftmax, mkl) -USE_JITKERNEL_MORE(kEmbSeqPool, mkl) -USE_JITKERNEL_MORE(kSgd, mkl) -USE_JITKERNEL_MORE(kVBroadcast, mkl) +USE_JITKERNEL_MORE_LITE(kMatMul, mkl) +USE_JITKERNEL_MORE_LITE(kVMul, mkl) +USE_JITKERNEL_MORE_LITE(kVAdd, mkl) +USE_JITKERNEL_MORE_LITE(kVScal, mkl) +USE_JITKERNEL_MORE_LITE(kStrideScal, mkl) +USE_JITKERNEL_MORE_LITE(kVExp, mkl) 
+USE_JITKERNEL_MORE_LITE(kVSquare, mkl) +USE_JITKERNEL_MORE_LITE(kVCopy, mkl) +USE_JITKERNEL_MORE_LITE(kVSigmoid, mkl) +USE_JITKERNEL_MORE_LITE(kVTanh, mkl) +USE_JITKERNEL_MORE_LITE(kSeqPool, mkl) +USE_JITKERNEL_MORE_LITE(kSoftmax, mkl) +USE_JITKERNEL_MORE_LITE(kEmbSeqPool, mkl) +USE_JITKERNEL_MORE_LITE(kSgd, mkl) +USE_JITKERNEL_MORE_LITE(kVBroadcast, mkl) diff --git a/lite/backends/x86/jit/more/mkl/mkl.h b/lite/backends/x86/jit/more/mkl/mkl.h index 8b713e537e74ca2d2a2e79dad7c325cda9c0e7a4..6bc791e64575b8f481f91ea3c28ea4896fe1860d 100644 --- a/lite/backends/x86/jit/more/mkl/mkl.h +++ b/lite/backends/x86/jit/more/mkl/mkl.h @@ -142,14 +142,13 @@ void StrideScal(const T* a, const T* x, T* y, int n, int stride); // remain is the product of dimension shapes after the axis dimension template void Softmax(const T* x, T* y, int n, int bs, int remain = 1) { - std::vector entities(bs); for (int i = 0; i < bs; ++i) { - entities[i] = x[i * n]; + T entity = x[i * n]; for (int c = 1; c < n; ++c) { - entities[i] = x[i * n + c] > entities[i] ? x[i * n + c] : entities[i]; + entity = x[i * n + c] > entity ? x[i * n + c] : entity; } for (int c = 0; c < n; ++c) { - y[i * n + c] = x[i * n + c] - entities[i]; + y[i * n + c] = x[i * n + c] - entity; } } VExp(y, y, n * bs); diff --git a/lite/backends/x86/jit/refer/CMakeLists.txt b/lite/backends/x86/jit/refer/CMakeLists.txt index 7133f596620410d37ffe52a2ee92b7a9974bf1cc..c52b21ad7dca102d18aee25aa60079bf03ae82b9 100644 --- a/lite/backends/x86/jit/refer/CMakeLists.txt +++ b/lite/backends/x86/jit/refer/CMakeLists.txt @@ -2,39 +2,39 @@ cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE) -function(USE_JITKERNEL_REFER TARGET) - file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n") +function(USE_JITKERNEL_REFER_LITE TARGET) + file(APPEND ${jit_file} "USE_JITKERNEL_REFER_LITE(${TARGET});\n") endfunction() # use refer kernel by name -USE_JITKERNEL_REFER(kVMul) -USE_JITKERNEL_REFER(kVAdd) -USE_JITKERNEL_REFER(kVAddRelu) -USE_JITKERNEL_REFER(kVSub) -USE_JITKERNEL_REFER(kVScal) -USE_JITKERNEL_REFER(kStrideScal) -USE_JITKERNEL_REFER(kVAddBias) -USE_JITKERNEL_REFER(kVCopy) -USE_JITKERNEL_REFER(kVRelu) -USE_JITKERNEL_REFER(kVIdentity) -USE_JITKERNEL_REFER(kVExp) -USE_JITKERNEL_REFER(kVSigmoid) -USE_JITKERNEL_REFER(kVTanh) -USE_JITKERNEL_REFER(kLSTMCtHt) -USE_JITKERNEL_REFER(kLSTMC1H1) -USE_JITKERNEL_REFER(kGRUH1) -USE_JITKERNEL_REFER(kGRUHtPart1) -USE_JITKERNEL_REFER(kGRUHtPart2) -USE_JITKERNEL_REFER(kCRFDecoding) -USE_JITKERNEL_REFER(kLayerNorm) -USE_JITKERNEL_REFER(kNCHW16CMulNC) -USE_JITKERNEL_REFER(kSeqPool) -USE_JITKERNEL_REFER(kMatMul) -USE_JITKERNEL_REFER(kVSquare) -USE_JITKERNEL_REFER(kHSum) -USE_JITKERNEL_REFER(kHMax) -USE_JITKERNEL_REFER(kStrideASum) -USE_JITKERNEL_REFER(kSoftmax) -USE_JITKERNEL_REFER(kEmbSeqPool) -USE_JITKERNEL_REFER(kSgd) -USE_JITKERNEL_REFER(kVBroadcast) +USE_JITKERNEL_REFER_LITE(kVMul) +USE_JITKERNEL_REFER_LITE(kVAdd) +USE_JITKERNEL_REFER_LITE(kVAddRelu) +USE_JITKERNEL_REFER_LITE(kVSub) +USE_JITKERNEL_REFER_LITE(kVScal) +USE_JITKERNEL_REFER_LITE(kStrideScal) +USE_JITKERNEL_REFER_LITE(kVAddBias) +USE_JITKERNEL_REFER_LITE(kVCopy) +USE_JITKERNEL_REFER_LITE(kVRelu) +USE_JITKERNEL_REFER_LITE(kVIdentity) +USE_JITKERNEL_REFER_LITE(kVExp) +USE_JITKERNEL_REFER_LITE(kVSigmoid) +USE_JITKERNEL_REFER_LITE(kVTanh) +USE_JITKERNEL_REFER_LITE(kLSTMCtHt) +USE_JITKERNEL_REFER_LITE(kLSTMC1H1) +USE_JITKERNEL_REFER_LITE(kGRUH1) +USE_JITKERNEL_REFER_LITE(kGRUHtPart1) 
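The mkl.h Softmax hunk above trades the heap-allocated `entities` vector for a single scalar running maximum per row; the maximum is subtracted before exponentiation so `exp()` cannot overflow. A stand-alone sketch of the whole row pass follows — note the real function is a template over `T` and defers the exponentiation and normalization to `VExp` and BLAS calls, while this sketch inlines everything with plain `float` loops:

```cpp
// Softmax over bs rows of length n, keeping one scalar maximum per row
// (a register) instead of a heap-allocated vector of per-row maxima.
#include <algorithm>
#include <cmath>

void SoftmaxRows(const float* x, float* y, int n, int bs) {
  for (int i = 0; i < bs; ++i) {
    // Running maximum of row i, the scalar that replaces entities[i].
    float row_max = x[i * n];
    for (int c = 1; c < n; ++c) row_max = std::max(row_max, x[i * n + c]);
    // Subtract the maximum before exp() for numerical stability.
    float sum = 0.f;
    for (int c = 0; c < n; ++c) {
      y[i * n + c] = std::exp(x[i * n + c] - row_max);
      sum += y[i * n + c];
    }
    for (int c = 0; c < n; ++c) y[i * n + c] /= sum;
  }
}
```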
+USE_JITKERNEL_REFER_LITE(kGRUHtPart2) +USE_JITKERNEL_REFER_LITE(kCRFDecoding) +USE_JITKERNEL_REFER_LITE(kLayerNorm) +USE_JITKERNEL_REFER_LITE(kNCHW16CMulNC) +USE_JITKERNEL_REFER_LITE(kSeqPool) +USE_JITKERNEL_REFER_LITE(kMatMul) +USE_JITKERNEL_REFER_LITE(kVSquare) +USE_JITKERNEL_REFER_LITE(kHSum) +USE_JITKERNEL_REFER_LITE(kHMax) +USE_JITKERNEL_REFER_LITE(kStrideASum) +USE_JITKERNEL_REFER_LITE(kSoftmax) +USE_JITKERNEL_REFER_LITE(kEmbSeqPool) +USE_JITKERNEL_REFER_LITE(kSgd) +USE_JITKERNEL_REFER_LITE(kVBroadcast) diff --git a/lite/backends/x86/jit/refer/refer.cc b/lite/backends/x86/jit/refer/refer.cc index e1b1240c5d5b0bc382fae8bd1b77f6c412522bdd..c47f8216abd999e66e914b208d96b8f352226f71 100644 --- a/lite/backends/x86/jit/refer/refer.cc +++ b/lite/backends/x86/jit/refer/refer.cc @@ -18,7 +18,7 @@ namespace refer = paddle::lite::jit::refer; #define REGISTER_REFER_KERNEL(func) \ - REGISTER_JITKERNEL_REFER( \ + REGISTER_JITKERNEL_REFER_LITE( \ k##func, refer::func##Kernel, refer::func##Kernel) REGISTER_REFER_KERNEL(VMul); diff --git a/lite/backends/x86/jit/registry.h b/lite/backends/x86/jit/registry.h index 7613a8dd4376045beb3636954668130e7220521e..65e3152d70fdd6262583cddced78e43513f0e0a1 100644 --- a/lite/backends/x86/jit/registry.h +++ b/lite/backends/x86/jit/registry.h @@ -77,16 +77,16 @@ class JitKernelRegistrar { void Touch() {} }; -#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \ +#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE(uniq_name, msg) \ struct __test_global_namespace_##uniq_name##__ {}; \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ __test_global_namespace_##uniq_name##__>::value, \ msg) // Refer always on CPUPlace -#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \ - STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ - __reg_jitkernel_##kernel_type##_refer_CPUPlace, \ +#define REGISTER_JITKERNEL_REFER_LITE(kernel_type, ...) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \ + __reg_litejitkernel_##kernel_type##_refer_CPUPlace, \ "REGISTER_KERNEL_REFER must be called in global namespace"); \ static ::paddle::lite::jit::JitKernelRegistrar< \ ::paddle::lite::jit::ReferKernelPool, \ @@ -94,84 +94,84 @@ class JitKernelRegistrar { __VA_ARGS__> \ __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \ ::paddle::lite::jit::KernelType::kernel_type); \ - int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \ + int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \ __jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \ return 0; \ } // kernel_type: should be in paddle::lite::jit::KernelType // place_type: should be one of CPUPlace and GPUPlace in paddle::platform -#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \ - STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ - __reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \ - "REGISTER_KERNEL_MORE must be called in global namespace"); \ - extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ +#define REGISTER_KERNEL_MORE_LITE(kernel_type, impl_type, place_type, ...) 
\ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \ + __reg_litejitkernel_##kernel_type##_##impl_type##_##place_type, \ + "REGISTER_KERNEL_MORE_LITE must be called in global namespace"); \ + extern int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \ - UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + UNUSED = LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ static ::paddle::lite::jit::JitKernelRegistrar< \ ::paddle::lite::jit::KernelPool, \ ::paddle::lite::fluid::place_type, \ __VA_ARGS__> \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \ ::paddle::lite::jit::KernelType::kernel_type); \ - int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \ + int LiteTouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \ __jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \ .Touch(); \ return 0; \ } #define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \ - REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__) - -#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \ - REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) - -#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \ - STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ - __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ - "REGISTER_JITKERNEL_GEN must be called in global namespace"); \ - extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ - static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \ - TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ - static ::paddle::lite::jit::JitKernelRegistrar< \ - ::paddle::lite::jit::JitCodeCreatorPool, \ - ::paddle::lite::fluid::CPUPlace, \ - __VA_ARGS__> \ - __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \ - ::paddle::lite::jit::KernelType::kernel_type); \ - int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \ - __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \ - return 0; \ + REGISTER_KERNEL_MORE_LITE(kernel_type, impl_type, CPUPlace, __VA_ARGS__) + +#define REGISTER_GPUKERNEL_MORE_LITE(kernel_type, impl_type, ...) \ + REGISTER_KERNEL_MORE_LITE(kernel_type, impl_type, GPUPlace, __VA_ARGS__) + +#define REGISTER_JITKERNEL_GEN_LITE(kernel_type, ...) 
\ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \ + __reg_litejitkernel_gen_##kernel_type##_CPUPlace_, \ + "REGISTER_JITKERNEL_GEN_LITE must be called in global namespace"); \ + extern int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \ + LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static ::paddle::lite::jit::JitKernelRegistrar< \ + ::paddle::lite::jit::JitCodeCreatorPool, \ + ::paddle::lite::fluid::CPUPlace, \ + __VA_ARGS__> \ + __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \ + ::paddle::lite::jit::KernelType::kernel_type); \ + int LiteTouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \ + __jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \ + return 0; \ } -#define USE_JITKERNEL_GEN(kernel_type) \ - STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ - __reg_jitkernel_gen_##kernel_type##_CPUPlace_, \ - "USE_JITKERNEL_GEN must be called in global namespace"); \ - extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \ - static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \ - TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() - -#define USE_JITKERNEL_REFER(kernel_type) \ - STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ - __reg_jitkernel_##kernel_type##_refer_CPUPlace_, \ - "USE_JITKERNEL_REFER must be called in global namespace"); \ - extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ - static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \ - TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() - -#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \ - STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \ - __reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \ - "USE_JITKERNEL_MORE must be called in global namespace"); \ - extern int \ - TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \ - static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \ - UNUSED = \ - TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() - -#define USE_JITKERNEL_MORE(kernel_type, impl_type) \ - USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace) +#define USE_JITKERNEL_GEN_LITE(kernel_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \ + __reg_litejitkernel_gen_##kernel_type##_CPUPlace_, \ + "USE_JITKERNEL_GEN_LITE must be called in global namespace"); \ + extern int LiteTouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \ + static int use_litejitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \ + LiteTouchJitKernelReg_gen_##kernel_type##_CPUPlace_() + +#define USE_JITKERNEL_REFER_LITE(kernel_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \ + __reg_litejitkernel_##kernel_type##_refer_CPUPlace_, \ + "USE_JITKERNEL_REFER_LITE must be called in global namespace"); \ + extern int LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \ + static int use_litejitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \ + LiteTouchJitKernelReg_##kernel_type##_refer_CPUPlace_() + +#define USE_KERNEL_MORE_LITE(kernel_type, impl_type, place_type) \ + STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE_LITE( \ + __reg_litejitkernel_##kernel_type##_##impl_type##_##place_type##_, \ + "USE_JITKERNEL_MORE_LITE must be called in global namespace"); \ + extern int \ + LiteTouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \ + static int use_litejitkernel_##kernel_type##_##impl_type##_##place_type##_ \ + UNUSED = \ + LiteTouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() + +#define 
USE_JITKERNEL_MORE_LITE(kernel_type, impl_type) \ + USE_KERNEL_MORE_LITE(kernel_type, impl_type, CPUPlace) } // namespace jit } // namespace lite diff --git a/lite/backends/x86/math/beam_search.cc b/lite/backends/x86/math/beam_search.cc index bbe35b4de5508c70496e5c8566c8d1b982a7155c..8d61fb3bbb97705c697fba934e6cab9424f85bad 100644 --- a/lite/backends/x86/math/beam_search.cc +++ b/lite/backends/x86/math/beam_search.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "lite/backends/x86/math/beam_search.h" #include +#include #include #include "lite/fluid/lod.h" diff --git a/lite/backends/x86/math/detail/avx_mathfun.h b/lite/backends/x86/math/detail/avx_mathfun.h index c95c881512900efb4b39df3ba16b8de686caefcb..2ad0866d6346a24690b30d0da317c6d86e9aebba 100644 --- a/lite/backends/x86/math/detail/avx_mathfun.h +++ b/lite/backends/x86/math/detail/avx_mathfun.h @@ -41,9 +41,11 @@ (this is the zlib license) */ - +#pragma once #include "lite/backends/x86/cpu_info.h" +namespace paddle { +namespace lite { /* __m128 is ugly to write */ typedef __m256 v8sf; // vector of 8 float (avx) typedef __m256i v8si; // vector of 8 int (avx) @@ -134,7 +136,7 @@ typedef union imm_xmm_union { return (ret); \ } -//#warning "Using SSE2 to perform AVX2 bitshift ops" +// #warning "Using SSE2 to perform AVX2 bitshift ops" AVX2_BITOP_USING_SSE2(slli_epi32) AVX2_BITOP_USING_SSE2(srli_epi32) @@ -152,7 +154,7 @@ AVX2_BITOP_USING_SSE2(srli_epi32) return (ret); \ } -//#warning "Using SSE2 to perform AVX2 integer ops" +// #warning "Using SSE2 to perform AVX2 integer ops" AVX2_INTOP_USING_SSE2(and_si128) AVX2_INTOP_USING_SSE2(andnot_si128) AVX2_INTOP_USING_SSE2(cmpeq_epi32) @@ -175,23 +177,23 @@ AVX2_INTOP_USING_SSE2(add_epi32) */ v8sf log256_ps(v8sf x) { v8si imm0; - v8sf one = *(v8sf *)_ps256_1; + v8sf one = *(v8sf *)_ps256_1; // NOLINT // v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps()); v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS); - x = _mm256_max_ps( - x, *(v8sf *)_ps256_min_norm_pos); /* cut off denormalized stuff */ + x = _mm256_max_ps(x, *(v8sf *)_ps256_min_norm_pos); // NOLINT + /* cut off denormalized stuff */ // NOLINT // can be done with AVX2 imm0 = avx2_mm256_srli_epi32(_mm256_castps_si256(x), 23); /* keep only the fractional part */ - x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_mant_mask); - x = _mm256_or_ps(x, *(v8sf *)_ps256_0p5); + x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_mant_mask); // NOLINT + x = _mm256_or_ps(x, *(v8sf *)_ps256_0p5); // NOLINT // this is again another AVX2 instruction - imm0 = avx2_mm256_sub_epi32(imm0, *(v8si *)_pi32_256_0x7f); + imm0 = avx2_mm256_sub_epi32(imm0, *(v8si *)_pi32_256_0x7f); // NOLINT v8sf e = _mm256_cvtepi32_ps(imm0); e = _mm256_add_ps(e, one); @@ -203,7 +205,8 @@ v8sf log256_ps(v8sf x) { } else { x = x - 1.0; } */ // v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF); - v8sf mask = _mm256_cmp_ps(x, *(v8sf *)_ps256_cephes_SQRTHF, _CMP_LT_OS); + v8sf mask = + _mm256_cmp_ps(x, *(v8sf *)_ps256_cephes_SQRTHF, _CMP_LT_OS); // NOLINT v8sf tmp = _mm256_and_ps(x, mask); x = _mm256_sub_ps(x, one); e = _mm256_sub_ps(e, _mm256_and_ps(one, mask)); @@ -211,34 +214,34 @@ v8sf log256_ps(v8sf x) { v8sf z = _mm256_mul_ps(x, x); - v8sf y = *(v8sf *)_ps256_cephes_log_p0; + v8sf y = *(v8sf *)_ps256_cephes_log_p0; // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p1); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p2); 
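Before moving on to the avx_mathfun.h lint fixes: the registry.h changes above rename the entire macro family but preserve its structure, which is the classic static-registrar idiom. A self-contained toy version (all names hypothetical) shows why every `REGISTER_*` macro emits a `Touch` function and every `USE_*` macro references it:

```cpp
// Toy model of what the *_LITE macros expand to: a static registrar object
// registers the kernel during static initialization, and a Touch() function
// gives other translation units a symbol to reference so the linker cannot
// dead-strip the object file holding the registrar.
#include <iostream>

struct Registrar {
  explicit Registrar(const char* name) {
    std::cout << "register " << name << "\n";  // Stand-in for pool insertion.
  }
  void Touch() {}  // No-op; referencing it keeps the registrar alive.
};

// What REGISTER_JITKERNEL_*_LITE emits in the kernel's translation unit:
static Registrar g_demo_registrar("kDemo");
int LiteTouchDemoReg() {
  g_demo_registrar.Touch();
  return 0;
}

// What USE_JITKERNEL_*_LITE expands to in a consuming translation unit:
extern int LiteTouchDemoReg();
static int use_demo = LiteTouchDemoReg();

int main() { return use_demo; }
```

Without the `extern`/`static int` pair in the consumer, the linker could drop the registrar's object file and the kernel would silently never register. It is also why `REGISTER_KERNEL_MORE_LITE` and `REGISTER_JITKERNEL_GEN_LITE` touch the refer kernel's registrar first: the `__assert_*_has_refer_` initializer fails to link unless a refer fallback was registered for the same kernel type.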
+ y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p2); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p3); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p3); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p4); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p4); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p5); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p5); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p6); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p6); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p7); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p7); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p8); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_log_p8); // NOLINT y = _mm256_mul_ps(y, x); y = _mm256_mul_ps(y, z); - tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q1); + tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q1); // NOLINT y = _mm256_add_ps(y, tmp); - tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); + tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); // NOLINT y = _mm256_sub_ps(y, tmp); - tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q2); + tmp = _mm256_mul_ps(e, *(v8sf *)_ps256_cephes_log_q2); // NOLINT x = _mm256_add_ps(x, y); x = _mm256_add_ps(x, tmp); x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN @@ -262,14 +265,14 @@ _PS256_CONST(cephes_exp_p5, 5.0000001201E-1); v8sf exp256_ps(v8sf x) { v8sf tmp = _mm256_setzero_ps(), fx; v8si imm0; - v8sf one = *(v8sf *)_ps256_1; + v8sf one = *(v8sf *)_ps256_1; // NOLINT - x = _mm256_min_ps(x, *(v8sf *)_ps256_exp_hi); - x = _mm256_max_ps(x, *(v8sf *)_ps256_exp_lo); + x = _mm256_min_ps(x, *(v8sf *)_ps256_exp_hi); // NOLINT + x = _mm256_max_ps(x, *(v8sf *)_ps256_exp_lo); // NOLINT /* express exp(x) as exp(g + n*log(2)) */ - fx = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_LOG2EF); - fx = _mm256_add_ps(fx, *(v8sf *)_ps256_0p5); + fx = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_LOG2EF); // NOLINT + fx = _mm256_add_ps(fx, *(v8sf *)_ps256_0p5); // NOLINT /* how to perform a floorf with SSE: just below */ // imm0 = _mm256_cvttps_epi32(fx); @@ -283,24 +286,24 @@ v8sf exp256_ps(v8sf x) { mask = _mm256_and_ps(mask, one); fx = _mm256_sub_ps(tmp, mask); - tmp = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C1); - v8sf z = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C2); + tmp = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C1); // NOLINT + v8sf z = _mm256_mul_ps(fx, *(v8sf *)_ps256_cephes_exp_C2); // NOLINT x = _mm256_sub_ps(x, tmp); x = _mm256_sub_ps(x, z); z = _mm256_mul_ps(x, x); - v8sf y = *(v8sf *)_ps256_cephes_exp_p0; + v8sf y = *(v8sf *)_ps256_cephes_exp_p0; // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p1); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p2); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p2); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p3); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p3); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p4); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p4); // NOLINT y = _mm256_mul_ps(y, x); - y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p5); + y = _mm256_add_ps(y, *(v8sf *)_ps256_cephes_exp_p5); // NOLINT y = 
_mm256_mul_ps(y, z); y = _mm256_add_ps(y, x); y = _mm256_add_ps(y, one); @@ -308,7 +311,7 @@ v8sf exp256_ps(v8sf x) { /* build 2^n */ imm0 = _mm256_cvttps_epi32(fx); // another two AVX2 instructions - imm0 = avx2_mm256_add_epi32(imm0, *(v8si *)_pi32_256_0x7f); + imm0 = avx2_mm256_add_epi32(imm0, *(v8si *)_pi32_256_0x7f); // NOLINT imm0 = avx2_mm256_slli_epi32(imm0, 23); v8sf pow2n = _mm256_castsi256_ps(imm0); y = _mm256_mul_ps(y, pow2n); @@ -349,12 +352,12 @@ v8sf sin256_ps(v8sf x) { // any x sign_bit = x; /* take the absolute value */ - x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask); + x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask); // NOLINT /* extract the sign bit (upper one) */ - sign_bit = _mm256_and_ps(sign_bit, *(v8sf *)_ps256_sign_mask); + sign_bit = _mm256_and_ps(sign_bit, *(v8sf *)_ps256_sign_mask); // NOLINT /* scale by 4/Pi */ - y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI); + y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI); // NOLINT /* Here we start a series of integer operations, which are in the @@ -367,12 +370,12 @@ v8sf sin256_ps(v8sf x) { // any x imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ // another two AVX2 instruction - imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1); - imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1); + imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1); // NOLINT + imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1); // NOLINT y = _mm256_cvtepi32_ps(imm2); /* get the swap sign flag */ - imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4); + imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4); // NOLINT imm0 = avx2_mm256_slli_epi32(imm0, 29); /* get the polynom selection mask there is one polynom for 0 <= x <= Pi/4 @@ -380,31 +383,31 @@ v8sf sin256_ps(v8sf x) { // any x Both branches will be computed. 
*/ - imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2); - imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0); + imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2); // NOLINT + imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0); // NOLINT #else /* we use SSE2 routines to perform the integer ops */ COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2); - imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1); + imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1); // NOLINT + imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1); // NOLINT - imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1); + imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1); // NOLINT + imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1); // NOLINT COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); y = _mm256_cvtepi32_ps(imm2); - imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4); - imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4); + imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4); // NOLINT + imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4); // NOLINT imm0_1 = _mm_slli_epi32(imm0_1, 29); imm0_2 = _mm_slli_epi32(imm0_2, 29); COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2); + imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2); // NOLINT + imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2); // NOLINT imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); @@ -418,9 +421,9 @@ v8sf sin256_ps(v8sf x) { // any x /* The magic pass: "Extended precision modular arithmetic" x = ((x - y * DP1) - y * DP2) - y * DP3; */ - xmm1 = *(v8sf *)_ps256_minus_cephes_DP1; - xmm2 = *(v8sf *)_ps256_minus_cephes_DP2; - xmm3 = *(v8sf *)_ps256_minus_cephes_DP3; + xmm1 = *(v8sf *)_ps256_minus_cephes_DP1; // NOLINT + xmm2 = *(v8sf *)_ps256_minus_cephes_DP2; // NOLINT + xmm3 = *(v8sf *)_ps256_minus_cephes_DP3; // NOLINT xmm1 = _mm256_mul_ps(y, xmm1); xmm2 = _mm256_mul_ps(y, xmm2); xmm3 = _mm256_mul_ps(y, xmm3); @@ -429,26 +432,26 @@ v8sf sin256_ps(v8sf x) { // any x x = _mm256_add_ps(x, xmm3); /* Evaluate the first polynom (0 <= x <= Pi/4) */ - y = *(v8sf *)_ps256_coscof_p0; + y = *(v8sf *)_ps256_coscof_p0; // NOLINT v8sf z = _mm256_mul_ps(x, x); y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1); // NOLINT y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2); + y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2); // NOLINT y = _mm256_mul_ps(y, z); y = _mm256_mul_ps(y, z); - v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); + v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); // NOLINT y = _mm256_sub_ps(y, tmp); - y = _mm256_add_ps(y, *(v8sf *)_ps256_1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_1); // NOLINT /* Evaluate the second polynom (Pi/4 <= x <= 0) */ - v8sf y2 = *(v8sf *)_ps256_sincof_p0; + v8sf y2 = *(v8sf *)_ps256_sincof_p0; // NOLINT y2 = _mm256_mul_ps(y2, z); - y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1); + y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1); // NOLINT y2 = _mm256_mul_ps(y2, z); - y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2); + y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2); // NOLINT y2 = _mm256_mul_ps(y2, z); y2 = _mm256_mul_ps(y2, x); y2 = _mm256_add_ps(y2, x); @@ -475,53 +478,53 @@ v8sf cos256_ps(v8sf x) { // any x #endif /* 
take the absolute value */ - x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask); + x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask); // NOLINT /* scale by 4/Pi */ - y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI); + y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI); // NOLINT #ifdef __AVX2__ /* store the integer part of y in mm0 */ imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ - imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1); - imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1); + imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1); // NOLINT + imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1); // NOLINT y = _mm256_cvtepi32_ps(imm2); - imm2 = avx2_mm256_sub_epi32(imm2, *(v8si *)_pi32_256_2); + imm2 = avx2_mm256_sub_epi32(imm2, *(v8si *)_pi32_256_2); // NOLINT /* get the swap sign flag */ - imm0 = avx2_mm256_andnot_si256(imm2, *(v8si *)_pi32_256_4); + imm0 = avx2_mm256_andnot_si256(imm2, *(v8si *)_pi32_256_4); // NOLINT imm0 = avx2_mm256_slli_epi32(imm0, 29); /* get the polynom selection mask */ - imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2); - imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0); + imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2); // NOLINT + imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0); // NOLINT #else /* we use SSE2 routines to perform the integer ops */ COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2); - imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1); + imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1); // NOLINT + imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1); // NOLINT - imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1); + imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1); // NOLINT + imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1); // NOLINT COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); y = _mm256_cvtepi32_ps(imm2); - imm2_1 = _mm_sub_epi32(imm2_1, *(v4si *)_pi32avx_2); - imm2_2 = _mm_sub_epi32(imm2_2, *(v4si *)_pi32avx_2); + imm2_1 = _mm_sub_epi32(imm2_1, *(v4si *)_pi32avx_2); // NOLINT + imm2_2 = _mm_sub_epi32(imm2_2, *(v4si *)_pi32avx_2); // NOLINT - imm0_1 = _mm_andnot_si128(imm2_1, *(v4si *)_pi32avx_4); - imm0_2 = _mm_andnot_si128(imm2_2, *(v4si *)_pi32avx_4); + imm0_1 = _mm_andnot_si128(imm2_1, *(v4si *)_pi32avx_4); // NOLINT + imm0_2 = _mm_andnot_si128(imm2_2, *(v4si *)_pi32avx_4); // NOLINT imm0_1 = _mm_slli_epi32(imm0_1, 29); imm0_2 = _mm_slli_epi32(imm0_2, 29); COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2); + imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2); // NOLINT + imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2); // NOLINT imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); @@ -534,9 +537,9 @@ v8sf cos256_ps(v8sf x) { // any x /* The magic pass: "Extended precision modular arithmetic" x = ((x - y * DP1) - y * DP2) - y * DP3; */ - xmm1 = *(v8sf *)_ps256_minus_cephes_DP1; - xmm2 = *(v8sf *)_ps256_minus_cephes_DP2; - xmm3 = *(v8sf *)_ps256_minus_cephes_DP3; + xmm1 = *(v8sf *)_ps256_minus_cephes_DP1; // NOLINT + xmm2 = *(v8sf *)_ps256_minus_cephes_DP2; // NOLINT + xmm3 = *(v8sf *)_ps256_minus_cephes_DP3; // NOLINT xmm1 = _mm256_mul_ps(y, xmm1); xmm2 = _mm256_mul_ps(y, xmm2); xmm3 = _mm256_mul_ps(y, xmm3); @@ -545,26 +548,26 @@ v8sf cos256_ps(v8sf x) { 
// any x x = _mm256_add_ps(x, xmm3); /* Evaluate the first polynom (0 <= x <= Pi/4) */ - y = *(v8sf *)_ps256_coscof_p0; + y = *(v8sf *)_ps256_coscof_p0; // NOLINT v8sf z = _mm256_mul_ps(x, x); y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1); // NOLINT y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2); + y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2); // NOLINT y = _mm256_mul_ps(y, z); y = _mm256_mul_ps(y, z); - v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); + v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); // NOLINT y = _mm256_sub_ps(y, tmp); - y = _mm256_add_ps(y, *(v8sf *)_ps256_1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_1); // NOLINT /* Evaluate the second polynom (Pi/4 <= x <= 0) */ - v8sf y2 = *(v8sf *)_ps256_sincof_p0; + v8sf y2 = *(v8sf *)_ps256_sincof_p0; // NOLINT y2 = _mm256_mul_ps(y2, z); - y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1); + y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1); // NOLINT y2 = _mm256_mul_ps(y2, z); - y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2); + y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2); // NOLINT y2 = _mm256_mul_ps(y2, z); y2 = _mm256_mul_ps(y2, x); y2 = _mm256_add_ps(y2, x); @@ -595,42 +598,43 @@ void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { sign_bit_sin = x; /* take the absolute value */ - x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask); + x = _mm256_and_ps(x, *(v8sf *)_ps256_inv_sign_mask); // NOLINT /* extract the sign bit (upper one) */ - sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf *)_ps256_sign_mask); + sign_bit_sin = + _mm256_and_ps(sign_bit_sin, *(v8sf *)_ps256_sign_mask); // NOLINT /* scale by 4/Pi */ - y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI); + y = _mm256_mul_ps(x, *(v8sf *)_ps256_cephes_FOPI); // NOLINT #ifdef __AVX2__ /* store the integer part of y in imm2 */ imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ - imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1); - imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1); + imm2 = avx2_mm256_add_epi32(imm2, *(v8si *)_pi32_256_1); // NOLINT + imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_inv1); // NOLINT y = _mm256_cvtepi32_ps(imm2); imm4 = imm2; /* get the swap sign flag for the sine */ - imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4); + imm0 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_4); // NOLINT imm0 = avx2_mm256_slli_epi32(imm0, 29); // v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0); /* get the polynom selection mask for the sine*/ - imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2); - imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0); + imm2 = avx2_mm256_and_si256(imm2, *(v8si *)_pi32_256_2); // NOLINT + imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si *)_pi32_256_0); // NOLINT // v8sf poly_mask = _mm256_castsi256_ps(imm2); #else /* we use SSE2 routines to perform the integer ops */ COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2); - imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1); + imm2_1 = _mm_add_epi32(imm2_1, *(v4si *)_pi32avx_1); // NOLINT + imm2_2 = _mm_add_epi32(imm2_2, *(v4si *)_pi32avx_1); // NOLINT - imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1); + imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_inv1); // NOLINT + imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_inv1); // NOLINT COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); y = 
_mm256_cvtepi32_ps(imm2); @@ -638,16 +642,16 @@ void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { imm4_1 = imm2_1; imm4_2 = imm2_2; - imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4); - imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4); + imm0_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_4); // NOLINT + imm0_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_4); // NOLINT imm0_1 = _mm_slli_epi32(imm0_1, 29); imm0_2 = _mm_slli_epi32(imm0_2, 29); COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2); + imm2_1 = _mm_and_si128(imm2_1, *(v4si *)_pi32avx_2); // NOLINT + imm2_2 = _mm_and_si128(imm2_2, *(v4si *)_pi32avx_2); // NOLINT imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); @@ -659,9 +663,9 @@ void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { /* The magic pass: "Extended precision modular arithmetic" x = ((x - y * DP1) - y * DP2) - y * DP3; */ - xmm1 = *(v8sf *)_ps256_minus_cephes_DP1; - xmm2 = *(v8sf *)_ps256_minus_cephes_DP2; - xmm3 = *(v8sf *)_ps256_minus_cephes_DP3; + xmm1 = *(v8sf *)_ps256_minus_cephes_DP1; // NOLINT + xmm2 = *(v8sf *)_ps256_minus_cephes_DP2; // NOLINT + xmm3 = *(v8sf *)_ps256_minus_cephes_DP3; // NOLINT xmm1 = _mm256_mul_ps(y, xmm1); xmm2 = _mm256_mul_ps(y, xmm2); xmm3 = _mm256_mul_ps(y, xmm3); @@ -670,15 +674,15 @@ void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { x = _mm256_add_ps(x, xmm3); #ifdef __AVX2__ - imm4 = avx2_mm256_sub_epi32(imm4, *(v8si *)_pi32_256_2); - imm4 = avx2_mm256_andnot_si256(imm4, *(v8si *)_pi32_256_4); + imm4 = avx2_mm256_sub_epi32(imm4, *(v8si *)_pi32_256_2); // NOLINT + imm4 = avx2_mm256_andnot_si256(imm4, *(v8si *)_pi32_256_4); // NOLINT imm4 = avx2_mm256_slli_epi32(imm4, 29); #else - imm4_1 = _mm_sub_epi32(imm4_1, *(v4si *)_pi32avx_2); - imm4_2 = _mm_sub_epi32(imm4_2, *(v4si *)_pi32avx_2); + imm4_1 = _mm_sub_epi32(imm4_1, *(v4si *)_pi32avx_2); // NOLINT + imm4_2 = _mm_sub_epi32(imm4_2, *(v4si *)_pi32avx_2); // NOLINT - imm4_1 = _mm_andnot_si128(imm4_1, *(v4si *)_pi32avx_4); - imm4_2 = _mm_andnot_si128(imm4_2, *(v4si *)_pi32avx_4); + imm4_1 = _mm_andnot_si128(imm4_1, *(v4si *)_pi32avx_4); // NOLINT + imm4_2 = _mm_andnot_si128(imm4_2, *(v4si *)_pi32avx_4); // NOLINT imm4_1 = _mm_slli_epi32(imm4_1, 29); imm4_2 = _mm_slli_epi32(imm4_2, 29); @@ -692,25 +696,25 @@ void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { /* Evaluate the first polynom (0 <= x <= Pi/4) */ v8sf z = _mm256_mul_ps(x, x); - y = *(v8sf *)_ps256_coscof_p0; + y = *(v8sf *)_ps256_coscof_p0; // NOLINT y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p1); // NOLINT y = _mm256_mul_ps(y, z); - y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2); + y = _mm256_add_ps(y, *(v8sf *)_ps256_coscof_p2); // NOLINT y = _mm256_mul_ps(y, z); y = _mm256_mul_ps(y, z); - v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); + v8sf tmp = _mm256_mul_ps(z, *(v8sf *)_ps256_0p5); // NOLINT y = _mm256_sub_ps(y, tmp); - y = _mm256_add_ps(y, *(v8sf *)_ps256_1); + y = _mm256_add_ps(y, *(v8sf *)_ps256_1); // NOLINT /* Evaluate the second polynom (Pi/4 <= x <= 0) */ - v8sf y2 = *(v8sf *)_ps256_sincof_p0; + v8sf y2 = *(v8sf *)_ps256_sincof_p0; // NOLINT y2 = _mm256_mul_ps(y2, z); - y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1); + y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p1); // NOLINT y2 = _mm256_mul_ps(y2, z); - y2 = _mm256_add_ps(y2, *(v8sf *)_ps256_sincof_p2); + y2 = _mm256_add_ps(y2, *(v8sf 
*)_ps256_sincof_p2); // NOLINT y2 = _mm256_mul_ps(y2, z); y2 = _mm256_mul_ps(y2, x); y2 = _mm256_add_ps(y2, x); @@ -729,3 +733,6 @@ void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { *s = _mm256_xor_ps(xmm1, sign_bit_sin); *c = _mm256_xor_ps(xmm2, sign_bit_cos); } + +} // namespace lite +} // namespace paddle diff --git a/lite/backends/x86/math/math_function.cc b/lite/backends/x86/math/math_function.cc index 822b7df936d84c21c226a13a48e8c09a2343f86a..a17807e8a997f0ecf908313a4cb205676e4fa4b8 100644 --- a/lite/backends/x86/math/math_function.cc +++ b/lite/backends/x86/math/math_function.cc @@ -110,11 +110,7 @@ void set_constant(const lite::Context& context, lite::Tensor* tensor, float value) { TensorSetConstantWithTarget func(context, tensor, value); - //#ifdef PADDLE_WITH_CUDA - // tensor->target().apply_visitor(func); - //#else func(); - //#endif } template @@ -123,17 +119,19 @@ struct RowwiseAdd { const lite::Tensor& input, const lite::Tensor& vector, lite::Tensor* output) { - auto in_dims = input.dims(); + const auto& in_dims = input.dims(); auto size = input.numel() / in_dims[0]; PADDLE_ENFORCE_EQ(vector.numel(), size); PADDLE_ENFORCE_EQ(output->dims(), in_dims); - auto in = lite::fluid::EigenMatrix::From(input); - auto vec = lite::fluid::EigenVector::Flatten(vector); - auto out = lite::fluid::EigenMatrix::From(*output); - + const T* input_data = input.data(); + const T* vector_data = vector.data(); + T* output_data = output->mutable_data(); for (int64_t i = 0; i < in_dims[0]; ++i) { - out.chip(i, 0) = in.chip(i, 0) + vec; + for (int64_t j = 0; j < size; ++j) { + output_data[i * in_dims[0] + j] = + input_data[i * in_dims[0] + j] + vector_data[j]; + } } } }; diff --git a/lite/backends/x86/math/pooling.cc b/lite/backends/x86/math/pooling.cc index 9da239f9c63371350403cc0bd0eecc94eab87590..ab6c1edb481f914d5751149aca2595fee550ca51 100644 --- a/lite/backends/x86/math/pooling.cc +++ b/lite/backends/x86/math/pooling.cc @@ -49,7 +49,7 @@ class Pool2dFunctor { const int stride_height = strides[0]; const int stride_width = strides[1]; const int padding_height = paddings[0]; - const int padding_width = paddings[1]; + const int padding_width = paddings[2]; const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; @@ -130,7 +130,7 @@ class Pool2dGradFunctor { const int stride_height = strides[0]; const int stride_width = strides[1]; const int padding_height = paddings[0]; - const int padding_width = paddings[1]; + const int padding_width = paddings[2]; const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; @@ -213,7 +213,7 @@ class MaxPool2dGradFunctor { const int stride_height = strides[0]; const int stride_width = strides[1]; const int padding_height = paddings[0]; - const int padding_width = paddings[1]; + const int padding_width = paddings[2]; const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; @@ -629,7 +629,7 @@ class MaxPool2dWithIndexFunctor { const int stride_height = strides[0]; const int stride_width = strides[1]; const int padding_height = paddings[0]; - const int padding_width = paddings[1]; + const int padding_width = paddings[2]; const int input_stride = input_height * input_width; const int output_stride = output_height * output_width; diff --git a/lite/backends/x86/math/sequence2batch.cc b/lite/backends/x86/math/sequence2batch.cc index ff215781f1efeb20a0e126a6e39a8f3508131abd..c12c05414d717dce706590a491ccae2384f3bfe5 100644 --- 
a/lite/backends/x86/math/sequence2batch.cc +++ b/lite/backends/x86/math/sequence2batch.cc @@ -24,12 +24,12 @@ class CopyMatrixRowsFunctor { public: void operator()(const lite::Context& context, const lite::Tensor& src, - std::vector index_lod, + const std::vector& index_lod, lite::Tensor* dst, bool is_src_index) { - size_t* index = index_lod.data(); - auto src_dims = src.dims(); - auto dst_dims = dst->dims(); + const size_t* index = index_lod.data(); + const auto& src_dims = src.dims(); + const auto& dst_dims = dst->dims(); PADDLE_ENFORCE_EQ( src_dims.size(), 2UL, "The src must be matrix with rank 2."); PADDLE_ENFORCE_EQ( diff --git a/lite/backends/x86/math/sequence2batch.h b/lite/backends/x86/math/sequence2batch.h index a97bfaf66607e5ea2efbd6f26f311fb4cd9dab67..a70cc5bf73522f97ab312fc48553b5316dbf8376 100644 --- a/lite/backends/x86/math/sequence2batch.h +++ b/lite/backends/x86/math/sequence2batch.h @@ -19,7 +19,6 @@ limitations under the License. */ #include "lite/core/context.h" #include "lite/core/tensor.h" #include "lite/fluid/eigen.h" -// #include "lite/fluid/lod.h" #include "lite/utils/paddle_enforce.h" namespace paddle { @@ -27,11 +26,6 @@ namespace lite { namespace x86 { namespace math { -template -using EigenMatrix = lite::fluid::EigenMatrix; - template class CopyMatrixRowsFunctor { public: @@ -42,7 +36,7 @@ class CopyMatrixRowsFunctor { // The indexed rows are based on the input index. void operator()(const lite::Context& context, const lite::Tensor& src, - std::vector index_lod, + const std::vector& index_lod, lite::Tensor* dst, bool is_src_index); }; @@ -56,6 +50,7 @@ class LoDTensor2BatchFunctor { // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} // struct SeqInfo { + SeqInfo() = default; SeqInfo(int start, int length, int seq_idx) : start(start), length(length), seq_idx(seq_idx) {} int start; @@ -89,10 +84,12 @@ class LoDTensor2BatchFunctor { const auto& lod = lods[0]; - std::vector seq_info; + std::vector seq_info(lod.size() - 1); for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; - seq_info.emplace_back(lod[seq_id], length, seq_id); + seq_info[seq_id].start = lod[seq_id]; + seq_info[seq_id].length = length; + seq_info[seq_id].seq_idx = seq_id; } std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) { @@ -122,21 +119,19 @@ class LoDTensor2BatchFunctor { // The max_seqlen represents batch size after rearranging the // input LodTensor. It is also the maximum length of input sequence. - lite::LoD batch_lods; - batch_lods.emplace_back(std::vector{0}); - batch_lods.emplace_back(std::vector{0}); - batch_lods.emplace_back(std::vector{0}); + LoD* batch_lods = batch->mutable_lod(); + batch_lods->resize(3); // batch_lods[0] is the start positions for batch LoDTensor int max_seqlen = seq_info[0].length; - batch_lods[0].resize(static_cast(max_seqlen + 1)); + batch_lods->at(0).resize(static_cast(max_seqlen + 1)); // batch_lods[1] is the raw index in the input LoDTensor - batch_lods[1].resize(static_cast(lod_tensor.dims()[0])); + batch_lods->at(1).resize(static_cast(lod_tensor.dims()[0])); // batch_lods[2] is the sort order for the input LoDTensor. 
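The `batch_starts` loop in sequence2batch.h above is easier to follow in isolation: sequences are sorted by length, longest first, and time step `n` of the batch holds one row from every sequence whose length exceeds `n`. A minimal sketch (simplified from the functor, using the lengths from the header's own `seq_info` comment):

```cpp
// Derive batch_lods[0] (batch_starts) from sequence lengths {4, 5, 1}.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
  std::vector<int> lengths = {4, 5, 1};
  std::sort(lengths.begin(), lengths.end(), std::greater<int>());  // {5, 4, 1}
  const int max_seqlen = lengths[0];
  std::vector<size_t> batch_starts(max_seqlen + 1, 0);
  for (int n = 0; n < max_seqlen; ++n) {
    size_t batch_id = batch_starts[n];
    for (int len : lengths) {
      if (len > n) ++batch_id;  // This sequence contributes a row at step n.
    }
    batch_starts[n + 1] = batch_id;
  }
  for (size_t s : batch_starts) std::printf("%zu ", s);
  std::printf("\n");  // Prints: 0 3 5 7 9 10
}
```

The final entry is 10, matching the 5 + 4 + 1 total rows of the input LoDTensor, and the refactor above now writes these offsets straight into `batch->mutable_lod()` instead of building a temporary `lite::LoD` and copying it in with `set_lod`.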
- batch_lods[2].resize(seq_info.size()); + batch_lods->at(2).resize(seq_info.size()); - size_t* batch_starts = batch_lods[0].data(); - size_t* seq2batch_idx = batch_lods[1].data(); + size_t* batch_starts = batch_lods->at(0).data(); + size_t* seq2batch_idx = batch_lods->at(1).data(); batch_starts[0] = 0; for (int n = 0; n < max_seqlen; n++) { auto batch_id = static_cast(batch_starts[n]); @@ -153,14 +148,13 @@ class LoDTensor2BatchFunctor { } batch_starts[n + 1] = static_cast(batch_id); } - size_t* seq_order = batch_lods[2].data(); + size_t* seq_order = batch_lods->at(2).data(); for (size_t i = 0; i < seq_info.size(); ++i) { seq_order[i] = seq_info[i].seq_idx; } - batch->set_lod(batch_lods); CopyMatrixRowsFunctor to_batch; - to_batch(context, lod_tensor, batch_lods[1], batch, true); + to_batch(context, lod_tensor, batch_lods->at(1), batch, true); } }; diff --git a/lite/backends/x86/math/softmax_impl.h b/lite/backends/x86/math/softmax_impl.h index ae997a8680b9012435d80b4aa9f592a775e62e85..ec45377bc55154a4a36ebc5c3684ab7efeeef88e 100644 --- a/lite/backends/x86/math/softmax_impl.h +++ b/lite/backends/x86/math/softmax_impl.h @@ -99,7 +99,7 @@ class SoftmaxFunctor> { const int axis_dim, const lite::Tensor* X, lite::Tensor* Y) { - auto in_dims = X->dims(); + const auto& in_dims = X->dims(); constexpr int kBatchDim = 0; constexpr int kClassDim = 1; @@ -140,7 +140,7 @@ class SoftmaxFunctor> { const int axis_dim, const lite::Tensor* X, lite::Tensor* Y) { - auto in_dims = X->dims(); + const auto& in_dims = X->dims(); const float* in_data = X->data(); float* out_data = Y->mutable_data(); const int kBatchDim = 0; diff --git a/lite/backends/x86/parallel.h b/lite/backends/x86/parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..0689ec4c234509cee6f10f8e0f7dd432edae5c4e --- /dev/null +++ b/lite/backends/x86/parallel.h @@ -0,0 +1,73 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#ifdef PADDLE_WITH_MKLML +#include +#include "lite/backends/x86/mklml.h" +#endif + +namespace paddle { +namespace lite { +namespace x86 { + +static void SetNumThreads(int num_threads) { +#ifdef PADDLE_WITH_MKLML + int real_num_threads = std::max(num_threads, 1); + x86::MKL_Set_Num_Threads(real_num_threads); + omp_set_num_threads(real_num_threads); +#endif +} + +static inline int64_t GetMaxThreads() { + int64_t num_threads = 1; +#ifdef PADDLE_WITH_MKLML + // Do not support nested omp parallem. + num_threads = omp_in_parallel() ? 
1 : omp_get_max_threads(); +#endif + return std::max(num_threads, 1L); +} + +using ThreadHandler = + std::function; + +static inline void RunParallelFor(const int64_t begin, + const int64_t end, + const ThreadHandler& f) { + if (begin >= end) { + return; + } + +#ifdef PADDLE_WITH_MKLML + int64_t num_threads = std::min(GetMaxThreads(), end - begin); + if (num_threads > 1) { +#pragma omp parallel num_threads(num_threads) + { + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = (end - begin + num_threads - 1) / num_threads; + int64_t begin_tid = begin + tid * chunk_size; + f(begin_tid, std::min(end, chunk_size + begin_tid)); + } + return; + } +#endif + + f(begin, end); +} + +} // namespace x86 +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/CMakeLists.txt b/lite/backends/xpu/CMakeLists.txt index f911f8e0e7c61481e1d4e309bc0635718be11206..4491fdeaefe9f16265bdee2c07ebb02b86a2b038 100644 --- a/lite/backends/xpu/CMakeLists.txt +++ b/lite/backends/xpu/CMakeLists.txt @@ -2,5 +2,4 @@ if(NOT LITE_WITH_XPU) return() endif() -lite_cc_library(xpu_runtime SRCS runtime.cc DEPS ${xpu_runtime_libs}) -lite_cc_library(xpu_builder SRCS builder.cc DEPS ${xpu_builder_libs} xpu_runtime tensor op scope) +lite_cc_library(device_xpu SRCS device.cc DEPS ${xpu_builder_libs} ${xpu_runtime_libs}) diff --git a/lite/backends/xpu/builder.cc b/lite/backends/xpu/builder.cc deleted file mode 100644 index 796eaf9c46ceb3d29f1ffdc4c86ac45509f07ba1..0000000000000000000000000000000000000000 --- a/lite/backends/xpu/builder.cc +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
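The new parallel.h ends here. A small usage sketch of `RunParallelFor` (header path as added by this diff; `ThreadHandler` receives the `[begin, end)` bounds of the chunk handed to each worker): each thread processes one contiguous chunk, and the call degrades to a plain serial invocation when MKLML/OpenMP is unavailable or only one thread is worthwhile.

```cpp
// Scale a vector in place using the RunParallelFor helper from this diff.
#include <cstdint>
#include <vector>

#include "lite/backends/x86/parallel.h"

void ScaleInPlace(std::vector<float>* data, float factor) {
  paddle::lite::x86::RunParallelFor(
      0,
      static_cast<int64_t>(data->size()),
      [&](int64_t begin, int64_t end) {
        // Each worker touches only its own [begin, end) chunk.
        for (int64_t i = begin; i < end; ++i) (*data)[i] *= factor;
      });
}
```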
- -#include "lite/backends/xpu/builder.h" -#include // NOLINT -#include -#include "lite/backends/xpu/runtime.h" - -namespace paddle { -namespace lite { -namespace xpu { - -bool HasInputArg(const OpInfo* op_info, - const Scope* scope, - const std::string& argname) { - auto iarg_names = op_info->input_argnames(); - if (std::find(iarg_names.begin(), iarg_names.end(), argname) != - iarg_names.end()) { - auto inputs = op_info->Input(argname); - if (inputs.empty()) { - return false; - } - auto var_name = inputs.front(); - auto var = scope->FindVar(var_name); - return var != nullptr; - } else { - return false; - } -} - -std::string UniqueName(const std::string& prefix) { - static std::mutex counter_mtx; - static std::unordered_map counter_map; - std::unique_lock counter_lck(counter_mtx); - int counter = 1; - auto it = counter_map.find(prefix); - if (it == counter_map.end()) { - counter_map[prefix] = counter; - } else { - counter = ++(it->second); - } - return prefix + "_" + std::to_string(counter); -} - -xtcl::DataType CvtPrecisionType(PrecisionType in_type) { - xtcl::DataType out_type = ::xtcl::Float(32); - switch (in_type) { - case PRECISION(kFloat): - out_type = ::xtcl::Float(32); - break; - case PRECISION(kInt8): - out_type = ::xtcl::Int(8); - break; - case PRECISION(kInt32): - out_type = ::xtcl::Int(32); - break; - default: - LOG(FATAL) << "Can not convert precision type(" << PrecisionToStr(in_type) - << ") from Lite to XPU"; - break; - } - return out_type; -} - -DLDataType CvtDataType(PrecisionType in_type) { - DLDataType out_type = {kDLFloat, 32, 1}; - switch (in_type) { - case PRECISION(kFloat): - out_type = {kDLFloat, 32, 1}; - break; - case PRECISION(kInt8): - out_type = {kDLInt, 8, 1}; - break; - case PRECISION(kInt32): - out_type = {kDLInt, 32, 1}; - break; - default: - LOG(FATAL) << "Can not convert data type(" << PrecisionToStr(in_type) - << ") from Lite to XPU"; - break; - } - return out_type; -} - -xtcl::Array CvtShape(const std::vector& in_shape) { - xtcl::Array out_shape; - for (auto dim : in_shape) { - out_shape.push_back(dim); - } - return out_shape; -} - -xtcl::Array CvtShape(const std::vector& in_shape) { - return CvtShape(std::vector(in_shape.begin(), in_shape.end())); -} - -xtcl::Array CvtShape(const DDim& in_dims) { - return CvtShape(in_dims.Vectorize()); -} - -std::shared_ptr CvtTensor(lite::Tensor* in_tensor, - std::vector out_shape, - PrecisionType in_ptype, - DataLayoutType in_ltype) { - uint8_t* in_data = nullptr; - auto in_size = in_tensor->dims().production(); - auto in_shape = in_tensor->dims().Vectorize(); - if (out_shape.empty()) { - out_shape = in_shape; - } - int in_bytes; - if (in_ptype == PRECISION(kFloat)) { - in_data = reinterpret_cast(in_tensor->mutable_data()); - in_bytes = in_size * sizeof(float); - } else if (in_ptype == PRECISION(kInt32)) { - in_data = reinterpret_cast(in_tensor->mutable_data()); - in_bytes = in_size * sizeof(int32_t); - } else if (in_ptype == PRECISION(kInt8)) { - in_data = reinterpret_cast(in_tensor->mutable_data()); - in_bytes = in_size * sizeof(int8_t); - } else { - LOG(FATAL) << "Unknow precision type " << PrecisionToStr(in_ptype); - } - auto out_tensor = std::make_shared( - xtcl::xNDArray::Empty(out_shape, CvtDataType(in_ptype), {kDLCPU, 0})); - auto out_data = - reinterpret_cast(out_tensor->ToDLPack()->dl_tensor.data); - std::memcpy(out_data, in_data, in_bytes); - return out_tensor; -} - -// Build the XPU subgraph to the XPU model, store the model data into the -// weight tensor of the graph op, and the model data will be 
loaded again -// by the graph computing kernel when the graph op is executed for inference. -// Due to the lack of XPU APIs for building and outputing the model data, -// the compiled XPU runtime object will be managed by the global variable -// 'DeviceInfo' and the key name for finding the runtime object will be -// stored in the weight tensor of graph op. -// TODO(hong19860320) Compile the XPU subgraph and output the compiled model -// data to the weight tensor of graph op. -bool BuildModel( - std::shared_ptr builder, - std::shared_ptr params, - std::vector>* outputs, - lite::Tensor* model) { - LOG(INFO) << "[XPU] Build Model."; - CHECK(builder != nullptr); - CHECK(outputs != nullptr); - CHECK_GT(outputs->size(), 0); - CHECK(model != nullptr); - - // build graph and fill all of constant params - xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0])); - auto target = xtcl::Target::Create("llvm"); - auto compiler = xtcl::network::xTensorCompiler(network, target); - compiler.SetParams(*params); // set the data of constant tensors - compiler.Build(); - - // create and register runtime - auto runtime = std::make_shared( - compiler.CreateRuntimeInstance()); - if (runtime == nullptr) { - LOG(WARNING) << "[XPU] Build Model failed!"; - return false; - } - std::string name = UniqueName("xpu"); - LOG(INFO) << "[XPU] Model Name: " << name; - DeviceInfo::Global().Insert(name, runtime); - model->Resize({static_cast(name.length() + 1)}); - memcpy(model->mutable_data(), - reinterpret_cast(name.c_str()), - name.length() + 1); - return true; -} - -} // namespace xpu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/xpu/device.cc b/lite/backends/xpu/device.cc new file mode 100644 index 0000000000000000000000000000000000000000..badde878ad870bfc5fcd1984e39923174a11e9e2 --- /dev/null +++ b/lite/backends/xpu/device.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
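device.cc above compiles the finalized network and hands back an owning runtime pointer, replacing the deleted name-registry dance in builder.cc. A hedged sketch of a call site follows; the element type of `outputs` is assumed to be `xtcl::xExpr*`, consistent with the dereference in `Device::Build`, since the template arguments were not preserved in this diff.

```cpp
// Sketch of a subgraph kernel driving the new XPU Device API. Populating the
// builder and params from the Lite subgraph is elided here.
#include <memory>
#include <vector>

#include "lite/backends/xpu/device.h"

std::unique_ptr<xtcl::network::xRuntimeInstance> CompileSubgraph(
    xtcl::network::xNetworkBuilder* builder,
    xtcl::network::xTensorCompiler::ParamNDArrayMap* params,
    std::vector<xtcl::xExpr*>* outputs) {
  // Build() wraps the outputs in a TupleNode, so multi-output subgraphs no
  // longer need to be compiled one output at a time.
  return paddle::lite::xpu::Device::Global().Build(builder, params, outputs);
}
```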
+ +#include "lite/backends/xpu/device.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace xpu { + +std::unique_ptr Device::Build( + xtcl::network::xNetworkBuilder* builder, + xtcl::network::xTensorCompiler::ParamNDArrayMap* params, + std::vector* outputs) { + VLOG(3) << "[XPU] Build model"; + CHECK(builder != nullptr); + CHECK(outputs != nullptr); + CHECK_GT(outputs->size(), 0); + + // The XPU compiler build the graph and fill all of the constant params, and + // use TupleNode to support multiple outputs + xtcl::Array all_outs; + for (size_t i = 0; i < outputs->size(); i++) { + all_outs.push_back(*outputs->at(i)); + } + xtcl::xNetwork network = + builder->FinalizeNetwork(xtcl::relay::TupleNode::make(all_outs)); + auto target = xtcl::NullValue(); + if (!target_.empty()) { + target = xtcl::Target::Create(target_); + } + xtcl::network::xTensorCompiler compiler(network, target); + compiler.SetParams(*params); // Set the data of constant tensors + compiler.Build(); + VLOG(3) << "[XPU] Build done"; + return std::unique_ptr( + new xtcl::network::xRuntimeInstance(compiler.CreateRuntimeInstance())); +} + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/device.h b/lite/backends/xpu/device.h new file mode 100644 index 0000000000000000000000000000000000000000..6de18d5466da6e6b791363d2e275ea72376c78b8 --- /dev/null +++ b/lite/backends/xpu/device.h @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace xpu { + +class Device { + public: + static Device& Global() { + static Device x; + return x; + } + Device() { + char* name = std::getenv("XPU_DEVICE_NAME"); + if (name) { + name_ = std::string(name); + } + // XPU_DEVICE_TARGET for XPU model building, which supports 'llvm' and 'xpu + // -libs=xdnn' + char* target = std::getenv("XPU_DEVICE_TARGET"); + if (target) { + target_ = std::string(target); + } + } + + // Build the XPU graph to the XPU runtime, return the XPU runtime which can be + // used to run inference. + std::unique_ptr Build( + xtcl::network::xNetworkBuilder* builder, + xtcl::network::xTensorCompiler::ParamNDArrayMap* params, + std::vector* outputs); + + const std::string name() const { return name_; } + const std::string target() const { return target_; } + + private: + std::string name_{""}; + std::string target_{""}; +}; + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/runtime.cc b/lite/backends/xpu/runtime.cc deleted file mode 100644 index a2c34b95758e8abf81c8294507d0ca60aad7c021..0000000000000000000000000000000000000000 --- a/lite/backends/xpu/runtime.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/xpu/runtime.h" -#include -#include "lite/utils/cp_logging.h" - -namespace paddle { -namespace lite { -namespace xpu { - -// Extract the model data and recover the XPU model for inference, the function -// is called by the graph computing kernel when the graph op is executed. -// Due to the lack of XPU APIs for loading and recovering the XPU model from -// memory, the key name is obtained from the weight tensor of graph op, to get -// the runtime object for inference from the global variable 'DeviceInfo'. -// TODO(hong19860320) Recover the XPU model from the weight tensor of graph op. -bool LoadModel(const lite::Tensor &model, - std::shared_ptr *runtime) { - LOG(INFO) << "[XPU] Load Model."; - CHECK_GT(model.dims().production(), 0); - std::string name(reinterpret_cast(model.data())); - LOG(INFO) << "[XPU] Model Name: " << name; - CHECK(runtime != nullptr); - *runtime = DeviceInfo::Global().Find(name); - if (*runtime == nullptr) { - LOG(WARNING) << "[XPU] Load Model failed!"; - return false; - } - return true; -} - -} // namespace xpu -} // namespace lite -} // namespace paddle diff --git a/lite/backends/xpu/runtime.h b/lite/backends/xpu/runtime.h deleted file mode 100644 index 4ff8d75bce6156d51a4988d427058da34460443f..0000000000000000000000000000000000000000 --- a/lite/backends/xpu/runtime.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
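// [Editor's note] A minimal usage sketch of the new API (illustrative only,
// not part of the patch; assumes builder, params and output exprs have been
// populated from the subgraph's ops elsewhere):
//
//   xtcl::network::xNetworkBuilder builder;
//   xtcl::network::xTensorCompiler::ParamNDArrayMap params;
//   std::vector<xtcl::xExpr*> outputs;
//   // ... translate the subgraph ops into builder/params/outputs ...
//   auto runtime =
//       paddle::lite::xpu::Device::Global().Build(&builder, &params, &outputs);
//   CHECK(runtime != nullptr);  // the instance is ready to run inference
//
// The target defaults to xtcl::NullValue<xtcl::Target>() and can be overridden
// through the XPU_DEVICE_TARGET environment variable, e.g. 'llvm'.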
-
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include "lite/core/tensor.h"
-
-namespace paddle {
-namespace lite {
-namespace xpu {
-
-class DeviceInfo {
- public:
-  static DeviceInfo& Global() {
-    static DeviceInfo x;
-    return x;
-  }
-  DeviceInfo() {}
-
-  void Insert(const std::string& name,
-              std::shared_ptr runtime) {
-    if (runtimes_.find(name) != runtimes_.end()) {
-      LOG(WARNING) << "[XPU] Model " << name << " already exists.";
-      return;
-    }
-    runtimes_.emplace(std::make_pair(name, runtime));
-  }
-
-  void Clear() { runtimes_.clear(); }
-
-  std::shared_ptr Find(
-      const std::string& name) const {
-    if (runtimes_.find(name) != runtimes_.end()) {
-      return runtimes_.at(name);
-    } else {
-      return nullptr;
-    }
-  }
-
- private:
-  int device_id_{0};
-  std::string device_name_{"default"};
-  std::unordered_map>
-      runtimes_;
-};
-
-bool LoadModel(const lite::Tensor& model,
-               std::shared_ptr* runtime);
-
-}  // namespace xpu
-}  // namespace lite
-}  // namespace paddle
diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt
index 5eecf1d815d30fe0ef10a55c6b6b351795fe63ae..1d0558451fce67433d966d1f4bff82af26459e33 100644
--- a/lite/core/CMakeLists.txt
+++ b/lite/core/CMakeLists.txt
@@ -6,7 +6,8 @@ lite_cc_library(target_wrapper SRCS target_wrapper.cc
     X86_DEPS target_wrapper_x86
     CUDA_DEPS target_wrapper_cuda
     CL_DEPS cl_target_wrapper
-    FPGA_DEPS fpga_target_wrapper)
+    FPGA_DEPS fpga_target_wrapper
+    BM_DEPS target_wrapper_bm)
 
 lite_cc_library(memory SRCS memory.cc DEPS target_wrapper CL_DEPS cl_target_wrapper)
 
@@ -33,9 +34,9 @@ lite_cc_library(scope SRCS scope.cc DEPS tensor)
 lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
 
 if (LITE_WITH_ARM)
-lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS npu_runtime)
+lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags)
 else()
-lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags XPU_DEPS xpu_runtime)
+lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags)
 endif()
 
 #-------------------------------------------- GET CODE META INFO ------------------------------------------
@@ -95,11 +96,19 @@ add_custom_command(
 add_custom_target(op_list_h DEPENDS ops.h)
 add_custom_target(kernel_list_h DEPENDS kernels.h)
 add_custom_target(all_kernel_faked_cc DEPENDS all_kernel_faked.cc)
-
+# create a header file to record ops info sorted by supported platforms
+add_custom_command(
+  COMMAND python ${CMAKE_SOURCE_DIR}/lite/tools/cmake_tools/record_supported_kernel_op.py
+          ${kernels_src_list}
+          ${ops_src_list}
+          ${CMAKE_BINARY_DIR}/supported_kernel_op_info.h
+  OUTPUT supported_kernel_op_info.h # not a real path to the output, to force it to execute every time.
+  )
+add_custom_target(supported_kernel_op_info_h DEPENDS supported_kernel_op_info.h)
#----------------------------------------------- NOT CHANGE -----------------------------------------------
 lite_cc_library(kernel SRCS kernel.cc
   DEPS context type_system target_wrapper any op_params tensor
-  PROFILE_DEPS basic_profiler
+  PROFILE_DEPS lite_profiler
   )
 lite_cc_library(op SRCS op_lite.cc
   DEPS scope op_registry target_wrapper kernel
   cpp_op_desc tensor
@@ -113,7 +122,7 @@ lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
 
 lite_cc_library(program SRCS program.cc
   DEPS op kernel model_parser ${ops} ${cpp_wrapper}
-  PROFILE_DEPS basic_profiler)
+  PROFILE_DEPS lite_profiler)
 
 if (NOT LITE_ON_TINY_PUBLISH)
   lite_cc_library(optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program)
diff --git a/lite/core/arena/CMakeLists.txt b/lite/core/arena/CMakeLists.txt
index bc77afd81e0859b9492b2068ce681098a9393923..0f3f36768bd5a079564002cbb6464d61bd5db3aa 100644
--- a/lite/core/arena/CMakeLists.txt
+++ b/lite/core/arena/CMakeLists.txt
@@ -5,6 +5,6 @@ endif()
 
 lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
 
-if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
-    lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+if((NOT LITE_WITH_OPENCL) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
+    lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${bm_kernels} ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc
index c59c078787b9a6778227ba6ba51230d1fc2104cb..fe36f1e1ba16ad85c44136b09a0d2e5d3fadf688 100644
--- a/lite/core/arena/framework.cc
+++ b/lite/core/arena/framework.cc
@@ -14,13 +14,38 @@
 
 #include "lite/core/arena/framework.h"
 #include "lite/core/context.h"
+#include "lite/operators/subgraph_op.h"
 
 namespace paddle {
 namespace lite {
 namespace arena {
 
 void TestCase::CreateInstruction() {
-  auto op = LiteOpRegistry::Global().Create(op_desc().Type());
+  std::shared_ptr<OpLite> op = nullptr;
+  if (place_.target == TARGET(kNPU) || place_.target == TARGET(kXPU)) {
+    // Create a new block desc to wrap the original op desc
+    int sub_block_idx = 0;
+    auto sub_block_desc = new cpp::BlockDesc();
+    sub_block_desc->ClearOps();
+    sub_block_desc->ClearVars();
+    auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
+    *sub_block_op_desc = *op_desc_;
+    // Add the block desc into the subgraph op which is used to replace the
+    // original op
+    op_desc_.reset(new cpp::OpDesc());
+    op_desc_->SetType("subgraph");
+    op_desc_->SetAttr("sub_block", sub_block_idx);
+    auto in_names = sub_block_op_desc->input_vars();
+    auto out_names = sub_block_op_desc->output_vars();
+    op_desc_->SetInput("Inputs", in_names);
+    op_desc_->SetOutput("Outputs", out_names);
+    op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names);
+    op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names);
+    op = LiteOpRegistry::Global().Create(op_desc().Type());
+    static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc);
+  } else {
+    op = LiteOpRegistry::Global().Create(op_desc().Type());
+  }
   CHECK(op) << "no op for " << op_desc().Type();
   op->Attach(*op_desc_, inst_scope_);
   auto kernels = op->CreateKernels({place_});
@@ -37,6 +62,9 @@ void TestCase::CreateInstruction() {
   // prepare context
   (*it)->SetContext(std::move(ctx_));
  instruction_.reset(new Instruction(op, std::move(*it)));
+#ifdef LITE_WITH_PROFILE
+  instruction_->set_profiler(new profile::Profiler());
+#endif
 }
 
 void TestCase::PrepareInputsForInstruction() {
@@ -65,6 +93,19 @@
   }
 }
 
+TestCase::~TestCase() {
+  if (op_desc_->Type() == "subgraph") {
+    // Release the sub-block desc of the subgraph op
+    auto subgraph_op = const_cast<operators::SubgraphOp*>(
+        static_cast<const operators::SubgraphOp*>(instruction_->op()));
+    CHECK(subgraph_op);
+    auto sub_block_desc = subgraph_op->GetSubBlock();
+    if (sub_block_desc) {
+      delete sub_block_desc;
+    }
+  }
+}
+
 }  // namespace arena
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h
index 412ac0c167b8abe6d196dc25d1bc5b193d02965d..85edda26e6591bada967165317de00b169a2d0cd 100644
--- a/lite/core/arena/framework.h
+++ b/lite/core/arena/framework.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <unordered_map>
 #include
 #include
 #include "lite/core/op_registry.h"
@@ -42,7 +43,7 @@ class TestCase {
       : place_(place), scope_(new Scope), alias_(alias) {
     ctx_ = ContextScheduler::Global().NewContext(place_.target);
   }
-  virtual ~TestCase() {}
+  virtual ~TestCase();
 
   void Prepare() {
     PrepareScopes();
@@ -77,6 +78,20 @@
   // kernel registry.
   void CheckKernelConsistWithDefinition() {}
 
+  // Get the real precision of the output for precision checking. When the
+  // declared precision obtained from the kernel is kAny, the precision of
+  // the output should be set in the test case.
+  bool GetPrecisionType(const std::string& var_name,
+                        PrecisionType* precision_type) {
+    auto res = precision_type_map_.find(var_name);
+    if (res == precision_type_map_.end()) {
+      return false;
+    } else {
+      *precision_type = precision_type_map_.at(var_name);
+      return true;
+    }
+  }
+
   Scope& scope() { return *scope_; }
 
   Scope* baseline_scope() { return base_scope_; }
@@ -92,7 +107,8 @@
   void SetCommonTensor(const std::string& var_name,
                        const DDim& ddim,
                        const T* data,
-                       const LoD& lod = {}) {
+                       const LoD& lod = {},
+                       bool is_persistable = false) {
     auto* tensor = scope_->NewTensor(var_name);
     tensor->Resize(ddim);
     auto* d = tensor->mutable_data<T>();
@@ -100,11 +116,26 @@
 
     // set lod
     if (!lod.empty()) *tensor->mutable_lod() = lod;
+    // set persistable
+    tensor->set_persistable(is_persistable);
   }
 
   // Prepare for the operator.
   virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
 
+  // Set the real precision of the output for precision checking. When the
+  // declared precision obtained from the kernel is kAny, the precision of
+  // the output should be set in the test case.
+  void SetPrecisionType(const std::string& var_name,
+                        const PrecisionType& precision_type) {
+    auto res = precision_type_map_.find(var_name);
+    if (res == precision_type_map_.end()) {
+      precision_type_map_.insert({var_name, precision_type});
+    } else {
+      precision_type_map_.at(var_name) = precision_type;
+    }
+  }
+
  public:
   const Instruction& instruction() { return *instruction_; }
 
@@ -148,6 +179,7 @@
   Scope* base_scope_{};
   std::unique_ptr op_desc_;
   std::unique_ptr instruction_;
+  std::unordered_map<std::string, PrecisionType> precision_type_map_;
 };
 
 class Arena {
@@ -159,13 +191,17 @@
     tester_->Prepare();
   }
 
-  bool TestPrecision() {
+  bool TestPrecision(const std::vector<std::string>& exclude_outs = {}) {
     tester_->RunBaseline(tester_->baseline_scope());
     tester_->RunInstruction();
 
     bool success = true;
     for (auto& out : tester_->op_desc().OutputArgumentNames()) {
       for (auto& var : tester_->op_desc().Output(out)) {
+        if (std::find(exclude_outs.begin(), exclude_outs.end(), var) !=
+            exclude_outs.end()) {
+          continue;
+        }
         success = success && CompareTensor(out, var);
       }
     }
@@ -180,7 +216,17 @@
     }
     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::high_resolution_clock::now() - timer);
-    LOG(INFO) << "average duration: " << duration.count() << " ms";
+
+    timer = std::chrono::high_resolution_clock::now();
+    for (int i = 0; i < times; i++) {
+      tester_->RunBaseline(tester_->baseline_scope());
+    }
+    auto duration_basic = std::chrono::duration_cast<std::chrono::milliseconds>(
+        std::chrono::high_resolution_clock::now() - timer);
+    LOG(INFO) << "average lite duration: " << duration.count() << " ms";
+    LOG(INFO) << "average basic duration: " << duration_basic.count() << " ms";
+    LOG(INFO) << "speed up ratio: lite_speed / basic_speed: "
+              << static_cast<float>(duration_basic.count()) / duration.count();
   }
 
  private:
@@ -189,8 +235,11 @@
     // get tensor type.
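// [Editor's note] Illustrative use of the precision map above (hypothetical
// test code, not part of the patch): a kernel that declares its output
// precision as kAny cannot be checked directly, so the test case registers
// the real precision up front and CompareTensor() below falls back to it.
//
//   // in a test case whose kernel output is declared kAny:
//   SetPrecisionType(out_var_name, PRECISION(kFloat));
//   // CompareTensor() then resolves kAny -> kFloat via GetPrecisionType().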
    const Type* type =
        tester_->instruction().kernel()->GetOutputDeclType(arg_name);
-
-    switch (type->precision()) {
+    auto precision_type = type->precision();
+    if (precision_type == PRECISION(kAny)) {
+      CHECK(tester_->GetPrecisionType(var_name, &precision_type));
+    }
+    switch (precision_type) {
      case PRECISION(kFloat):
        return tester_->CheckPrecision<float>(var_name, abs_error_);
      case PRECISION(kInt8):
@@ -199,7 +248,6 @@
        return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
      case PRECISION(kBool):
        return tester_->CheckPrecision<bool>(var_name, abs_error_);
-
      default:
        LOG(FATAL) << "not support type " << PrecisionToStr(type->precision());
        return false;
diff --git a/lite/core/context.h b/lite/core/context.h
index 545c6d2e8804f72a0bde854f9e5ae82c80b2b53c..653329e4f24b1f391ea41ed39819b60c8a598a3b 100644
--- a/lite/core/context.h
+++ b/lite/core/context.h
@@ -25,12 +25,6 @@
 #include "lite/backends/opencl/cl_context.h"
 #include "lite/backends/opencl/cl_runtime.h"
 #endif
-#ifdef LITE_WITH_NPU
-#include "lite/backends/npu/runtime.h"
-#endif
-#ifdef LITE_WITH_XPU
-#include "lite/backends/xpu/runtime.h"
-#endif
 
 #include
 #include
@@ -61,6 +55,7 @@
 using NPUContext = Context<TargetType::kNPU>;
 using XPUContext = Context<TargetType::kXPU>;
 using OpenCLContext = Context<TargetType::kOpenCL>;
 using FPGAContext = Context<TargetType::kFPGA>;
+using BMContext = Context<TargetType::kBM>;
 
 template <>
 class Context {
@@ -88,12 +83,29 @@
 };
 #endif
 
+#ifdef LITE_WITH_BM
+template <>
+class Context<TargetType::kBM> {
+ public:
+  Context() {}
+  explicit Context(const BMContext& ctx);
+  // NOTE: InitOnce should only be used by ContextScheduler
+  void InitOnce() { Init(0); }
+
+  void Init(int dev_id) { TargetWrapperBM::SetDevice(dev_id); }
+  void CopySharedTo(BMContext* ctx) {}
+  void* GetHandle() { return TargetWrapperBM::GetHandle(); }
+
+  std::string name() const { return "BMContext"; }
+};
+#endif
+
 #ifdef LITE_WITH_XPU
 template <>
 class Context<TargetType::kXPU> {
  public:
   Context() {}
-  explicit Context(const NPUContext& ctx);
+  explicit Context(const XPUContext& ctx);
   // NOTE: InitOnce should only be used by ContextScheduler
   void InitOnce() {}
   void CopySharedTo(XPUContext* ctx) {}
@@ -207,13 +219,6 @@ class Context {
     ctx->cublas_fp32_ = cublas_fp32_;
   }
 
-  CUDAContext& operator=(const CUDAContext& context) {
-    this->Init(
-        context.device_id_, context.exec_stream_id_, context.io_stream_id_);
-    this->cublas_fp32_ = context.cublas_fp32_;
-    return *this;
-  }
-
   const cudaStream_t& exec_stream() const { return exec_stream_; }
   void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
 
@@ -387,6 +392,12 @@ class ContextScheduler {
         kernel_contexts_[TargetType::kFPGA].As().CopySharedTo(
             &ctx->As());
         break;
+#endif
+#ifdef LITE_WITH_BM
+      case TARGET(kBM):
+        kernel_contexts_[TargetType::kBM].As<BMContext>().CopySharedTo(
+            &ctx->As<BMContext>());
+        break;
 #endif
       default:
 #ifndef LITE_ON_MODEL_OPTIMIZE_TOOL
@@ -425,6 +436,9 @@
 #endif
 #ifdef LITE_WITH_XPU
     InitContext<TargetType::kXPU, XPUContext>();
+#endif
+#ifdef LITE_WITH_BM
+    InitContext<TargetType::kBM, BMContext>();
 #endif
   }
 
diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc
index f5b757ac3ccd6310f6a6fd9fe6483d28ff7adbc6..6e0d743fb9d8d8af5e7168e292c1e85d76844383 100644
--- a/lite/core/device_info.cc
+++ b/lite/core/device_info.cc
@@ -79,7 +79,7 @@ const int DEFAULT_L3_CACHE_SIZE = 0;
 int get_cpu_num() {
 #ifdef LITE_WITH_LINUX
   // get cpu count from /sys/devices/system/cpu/cpunum/uevent
-  int max_cpu_num = 20;
+  int max_cpu_num = 128;
   int cpu_num = 0;
   for (int i = 0; i < max_cpu_num; ++i) {
     char path[256];
@@ -227,19 +227,24 @@ void get_cpu_arch(std::vector* archs, const int cpu_num) {
 #ifdef LITE_WITH_LINUX
std::string get_cpu_name() { - std::string cpu_name; + std::string cpu_name = ""; FILE* fp = fopen("/proc/cpuinfo", "rb"); if (!fp) { return ""; } char line[1024]; + bool first_model_name = true; while (!feof(fp)) { char* s = fgets(line, 1024, fp); if (!s) { break; } if (strstr(line, "Hardware") != NULL) { - cpu_name = std::string(line); + cpu_name += std::string(line); + } + if (strstr(line, "model name") != NULL && first_model_name) { + cpu_name += std::string(line); + first_model_name = false; } } #ifdef LITE_WITH_ANDROID @@ -816,6 +821,21 @@ bool DeviceInfo::SetCPUInfoByName() { SetFP16Info(1, 1); SetDotInfo(1, 1); return true; + } else if (dev_name_.find("FT2000PLUS") != std::string::npos) { + core_num_ = 64; + core_ids_.resize(core_num_); + big_core_ids_.resize(core_num_); + cluster_ids_.resize(core_num_); + for (int i = 0; i < core_num_; ++i) { + core_ids_[i] = i; + big_core_ids_[i] = i; + cluster_ids_[i] = 0; + } + little_core_ids_ = {}; + SetCacheInfo(0, 1, 64 * 1024); + SetCacheInfo(1, 1, 32 * 1024 * 1024); + SetCacheInfo(2, 1, 128 * 1024 * 1024); + return true; } return false; } diff --git a/lite/core/framework.proto b/lite/core/framework.proto index 5adf2a18b98c2a2d3e2f6e8f7dd5688150674dc6..84b5502ff7b369452e7c9988d185450934c78b03 100644 --- a/lite/core/framework.proto +++ b/lite/core/framework.proto @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; -option optimize_for = LITE_RUNTIME; package paddle.framework.proto; // Any incompatible changes to ProgramDesc and its dependencies should diff --git a/lite/core/kernel.h b/lite/core/kernel.h index 05d7a6b333810a8dc988d84a281f096babe8929f..18a1243c11652afc181f13f0f5a497858a30885f 100644 --- a/lite/core/kernel.h +++ b/lite/core/kernel.h @@ -31,7 +31,7 @@ #include "lite/utils/replace_stl/stream.h" #ifdef LITE_WITH_PROFILE -#include "lite/core/profile/basic_profiler.h" +#include "lite/core/profile/profiler.h" #endif // LITE_WITH_PROFILE namespace paddle { @@ -58,7 +58,10 @@ class KernelBase { virtual void Run() = 0; #ifdef LITE_WITH_PROFILE - void SetProfileID(uint32_t id) { profile_id_ = id; } + void SetProfiler(profile::Profiler* profiler, int id) { + profiler_ = profiler; + profile_id_ = id; + } #endif void Launch() { @@ -80,12 +83,11 @@ class KernelBase { #if defined(LITE_WITH_CUDA) WorkSpace::Global_CUDA().AllocReset(); #endif - #ifdef LITE_WITH_PROFILE - if (profile_id_ >= 0) { - profile::ProfileBlock x(profile_id_, "kernel"); - Run(); - } + profiler_->StopTiming(profile::Type::kCreate, profile_id_, ctx_.get()); + profiler_->StartTiming(profile::Type::kDispatch, profile_id_, ctx_.get()); + Run(); + profiler_->StopTiming(profile::Type::kDispatch, profile_id_, ctx_.get()); #else Run(); #endif @@ -175,6 +177,7 @@ class KernelBase { bool is_first_epoch_{true}; #ifdef LITE_WITH_PROFILE + profile::Profiler* profiler_{nullptr}; int profile_id_{-1}; #endif }; diff --git a/lite/core/lite.map b/lite/core/lite.map index 31adae42196c3d6b82628a2e433b13a4cb467b39..9cfd272eb6d3017a75b40481d25527d7c14478bf 100644 --- a/lite/core/lite.map +++ b/lite/core/lite.map @@ -1,6 +1,8 @@ { global: *paddle*; + *touch_*; + *mir_pass_*; local: *; }; diff --git a/lite/core/memory.cc b/lite/core/memory.cc index b3cb18b33630de6615812471e1acaab59c8e99b0..cfb0b3ae1765864200ecf2d70107a3aa0046899c 100644 --- a/lite/core/memory.cc +++ b/lite/core/memory.cc @@ -40,6 +40,11 @@ void* TargetMalloc(TargetType target, size_t size) { data = TargetWrapper::Malloc(size); break; #endif 
// LITE_WITH_OPENCL +#ifdef LITE_WITH_BM + case TargetType::kBM: + data = TargetWrapper::Malloc(size); + break; +#endif default: LOG(FATAL) << "Unknown supported target " << TargetToStr(target); } @@ -69,6 +74,11 @@ void TargetFree(TargetType target, void* data) { TargetWrapper::Free(data); break; #endif // LITE_WITH_CUDA +#ifdef LITE_WITH_BM + case TargetType::kBM: + TargetWrapper::Free(data); + break; +#endif default: LOG(FATAL) << "Unknown type"; } @@ -95,6 +105,11 @@ void TargetCopy(TargetType target, void* dst, const void* src, size_t size) { dst, src, size, IoDirection::DtoD); break; #endif +#ifdef LITE_WITH_BM + case TargetType::kBM: + TargetWrapper::MemcpySync(dst, src, size, IoDirection::DtoD); + break; +#endif #ifdef LITE_WITH_OPENCL case TargetType::kOpenCL: TargetWrapperCL::MemcpySync(dst, src, size, IoDirection::DtoD); diff --git a/lite/core/memory.h b/lite/core/memory.h index cb4ac044e7af6994e5e404f379eeb12290e34778..051d47bdde102f5fe058163d0c746fe3c4acf26e 100644 --- a/lite/core/memory.h +++ b/lite/core/memory.h @@ -25,6 +25,10 @@ #include "lite/backends/cuda/target_wrapper.h" #endif // LITE_WITH_CUDA +#ifdef LITE_WITH_BM +#include "lite/backends/bm/target_wrapper.h" +#endif // LITE_WITH_BM + namespace paddle { namespace lite { @@ -71,6 +75,11 @@ void CopySync(void* dst, const void* src, size_t size, IoDirection dir) { case TARGET(kFPGA): TargetWrapper::MemcpySync(dst, src, size, dir); break; +#endif +#ifdef LITE_WITH_BM + case TARGET(kBM): + TargetWrapper::MemcpySync(dst, src, size, dir); + break; #endif } } @@ -100,13 +109,14 @@ class Buffer { template void ResetLazyImage2D(TargetType target, const size_t img_w, - const size_t img_h) { + const size_t img_h, + void* host_ptr = nullptr) { size_t size = sizeof(T) * img_w * img_h * 4; // 4 for RGBA, un-used for opencl Image2D if (target != target_ || cl_image2d_width_ < img_w || cl_image2d_height_ < img_h) { Free(); - data_ = TargetWrapperCL::MallocImage(img_w, img_h); + data_ = TargetWrapperCL::MallocImage(img_w, img_h, host_ptr); target_ = target; space_ = size; // un-used for opencl Image2D cl_image2d_width_ = img_w; @@ -119,6 +129,7 @@ class Buffer { if (space_ > 0) { TargetFree(target_, data_); } + data_ = nullptr; target_ = TargetType::kHost; space_ = 0; } diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index a44b8348716449519486d37f6784e31ecc39f554..379ef67f2996519d0c8007d8f191efbd2166a9e3 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -16,10 +16,13 @@ lite_cc_library(mir_passes fusion/interpolate_fuse_pass.cc fusion/conv_elementwise_fuse_pass.cc fusion/conv_activation_fuse_pass.cc + fusion/var_conv_2d_activation_fuse_pass.cc fusion/conv_bn_fuse_pass.cc fusion/elementwise_add_activation_fuse_pass.cc fusion/quant_dequant_fuse_pass.cc + fusion/sequence_pool_concat_fuse_pass.cc elimination/identity_scale_eliminate_pass.cc + elimination/elementwise_mul_constant_eliminate_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_cast_pass.cc @@ -32,7 +35,8 @@ lite_cc_library(mir_passes demo_pass.cc runtime_context_assign_pass.cc memory_optimize_pass.cc - DEPS mir_pass types context ${mir_fusers} ${subgraph_passes}) + weight_quantization_preprocess_pass.cc + DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs}) # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS #mir_ssa_graph scope op diff --git a/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc 
new file mode 100644
index 0000000000000000000000000000000000000000..863c01ef0646794b5cbe54d7a81a8f26dbf164ae
--- /dev/null
+++ b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+namespace {
+
+class ElementwiseMulConstantEliminator : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* pre_op = OpNode("preop");    // the previous op's output needs to be updated
+    auto* post_op = OpNode("postop");  // the post op's input needs to be updated
+    // TODO(Superjomn) check has only one output
+    auto* x =
+        VarNode("x")->assert_is_op_input("elementwise_mul", "X")->AsOutput();
+    auto* y = VarNode("Y")->assert_is_op_input("elementwise_mul", "Y");
+
+    // create op nodes
+    auto* mul = OpNode("mul", "elementwise_mul")
+                    ->assert_is_op("elementwise_mul")
+                    ->AsIntermediate();
+
+    auto* fill_constant = OpNode("fill_constant", "fill_constant")
+                              ->assert_is_op("fill_constant")
+                              ->assert_op_attr("value", 1.)
+                              ->AsIntermediate();
+    // create output node
+    auto* mul_out =
+        VarNode("output")->assert_is_op_output("elementwise_mul", "Out");
+    // create topology.
+    std::vector<PMNode*> add_inputs{x, y};
+    *pre_op >> *x;
+    *fill_constant >> *y;
+    add_inputs >> *mul >> *mul_out;
+    *mul_out >> *post_op;
+
+    // The elementwise_mul (and its fill_constant input) will be eliminated,
+    // and the post op will be updated to read x directly.
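// [Editor's note, illustrative] Net effect of the pattern built above:
//
//   y = fill_constant(value=1.0)
//   out = elementwise_mul(x, y)      =>   post_op(x, ...)
//   post_op(out, ...)
//
// Multiplying by a constant 1.0 is a no-op, so fill_constant and
// elementwise_mul are removed and post_op is rewired to read x directly;
// the nodes marked intermediate below are the ones deleted from the graph.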
+    mul_out->AsIntermediate();  // the elementwise_mul output will be removed
+    y->AsIntermediate();        // the fill_constant output will be removed
+  }
+
+ private:
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    auto& post_op = matched.at("postop")->AsStmt();
+    auto op_info = *post_op.op_info();
+
+    op_info.UpdateAllInputs(matched.at("output")->AsArg().name,
+                            matched.at("x")->AsArg().name);
+    post_op.ResetOp(op_info, graph->valid_places());
+
+    IR_NODE_LINK_TO(matched.at("x"), matched.at("postop"));
+  }
+};
+
+}  // namespace
+
+class ElementwiseMulConstantEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    ElementwiseMulConstantEliminator eliminator;
+    eliminator(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(elementwise_mul_constant_eliminate_pass,
+                  paddle::lite::mir::ElementwiseMulConstantEliminatePass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
index acea48c742522d5b6b5f1f3b570fcbfe0c4be08d..345361047bbbad68cdd0b298a163214cbfe114fc 100644
--- a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
+++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc
@@ -25,7 +25,8 @@ namespace {
 class Eliminator : public FuseBase {
  public:
   void BuildPattern() override {
-    auto* pre_op = OpNode("preop");  // the previous op's output need update
+    // the previous op's output needs to be updated
+    auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block");
     // TODO(Superjomn) check has only one output
     auto* x = VarNode("x")->assert_is_op_input("scale", "X");
     auto* scale_op = OpNode("scale", "scale")
diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt
index 5ac52837551f0b78d67dfe1733fe354ee2cf7f01..e65e72cf7b367ee8477f3f783ae4d82372529864 100644
--- a/lite/core/mir/fusion/CMakeLists.txt
+++ b/lite/core/mir/fusion/CMakeLists.txt
@@ -10,6 +10,9 @@ lite_cc_library(fuse_conv_elementwise
 lite_cc_library(fuse_conv_activation
     SRCS conv_activation_fuser.cc
     DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_var_conv_activation
+    SRCS var_conv_2d_activation_fuser.cc
+    DEPS pattern_matcher_high_api)
 lite_cc_library(fuse_conv_bn
     SRCS conv_bn_fuser.cc
     DEPS pattern_matcher_high_api)
@@ -25,17 +28,22 @@ lite_cc_library(fuse_transpose_softmax_transpose
 lite_cc_library(fuse_interpolate
     SRCS interpolate_fuser.cc
     DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_sequence_pool_concat
+    SRCS sequence_pool_concat_fuser.cc
+    DEPS pattern_matcher_high_api)
 
 set(mir_fusers
     fuse_fc
     fuse_shuffle_channel
     fuse_conv_elementwise
     fuse_conv_activation
+    fuse_var_conv_activation
     fuse_conv_bn
     fuse_quant_dequant
     fuse_elementwise_add_activation
     fuse_transpose_softmax_transpose
     fuse_interpolate
+    fuse_sequence_pool_concat
     CACHE INTERNAL "fusers")
 
 if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index ff064fb2ee93fc540e932da36fb07bb78eef989a..b688bbc1083a6ab0f521381c4a988a12badc3141 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -29,8 +29,13 @@ void ConvActivationFusePass::Apply(const std::unique_ptr& graph) {
       act_types.push_back("leaky_relu");
       break;
     }
+    if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
+      act_types.push_back("relu6");
+
act_types.push_back("leaky_relu"); + break; + } } - for (auto conv_type : {"conv2d", "depthwise_conv2d"}) { + for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) { for (auto act_type : act_types) { for (auto has_bias : {true, false}) { fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias); @@ -47,4 +52,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_activation_fuse_pass, paddle::lite::mir::ConvActivationFusePass) .BindTargets({TARGET(kAny)}) + .ExcludeTargets({TARGET(kXPU)}) .BindKernel("conv2d"); diff --git a/lite/core/mir/fusion/conv_activation_fuser.cc b/lite/core/mir/fusion/conv_activation_fuser.cc index 6ba11a6a4e82416eb386ec3b34c71183cef5adcb..993fe4e9441824d0c5539e6555e5e12d87e5b98f 100644 --- a/lite/core/mir/fusion/conv_activation_fuser.cc +++ b/lite/core/mir/fusion/conv_activation_fuser.cc @@ -79,6 +79,9 @@ cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetAttr("act_type", act_type_); if (act_type_ == "relu") { op_desc.SetAttr("fuse_relu", true); + } else if (act_type_ == "relu6") { + float alpha = act_op_desc.GetAttr("threshold"); + op_desc.SetAttr("fuse_brelu_threshold", alpha); } else if (act_type_ == "leaky_relu") { float alpha = act_op_desc.GetAttr("alpha"); op_desc.SetAttr("leaky_relu_alpha", alpha); diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index d9d9c1bbf55bd33c31aa9a22de934d4eae8657c6..f5a7837b53650e08f9632b499a4c2ab1faeaeedf 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -27,7 +27,6 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { // initialze fuser params std::vector conv_has_bias_cases{true, false}; std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; - // start fuse using params for (auto conv_has_bias : conv_has_bias_cases) { for (auto conv_type : conv_type_cases) { @@ -45,4 +44,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) .BindTargets({TARGET(kAny)}) - .ExcludeTargets({TARGET(kX86)}); + .ExcludeTargets({TARGET(kX86), TARGET(kXPU), TARGET(kBM)}); diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index ec07278eed1f259c45e225497f94d682b544c57c..0f5bb64e10dd61c3edf4ddd32569a2d365651cdf 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -100,14 +100,17 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { auto eps = matched.at("bn")->stmt()->op_info()->GetAttr("epsilon"); // conv - auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name) - ->GetMutable(); + std::string conv_weight_name = matched.at("conv_weight")->arg()->name; + auto conv_weight_t = + scope->FindVar(conv_weight_name)->GetMutable(); CHECK_EQ(static_cast(bn_scale_t->data_size()), static_cast(conv_weight_t->dims()[0])) << "The BN bias's size should be equal to the size of the first " << "dim size of the conv weights"; size_t weight_num = conv_weight_t->data_size(); bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false; + bool is_weight_quantization = + conv_op_desc->HasAttr("quantize_weight_bits") ? 
true : false; // comupte BN alpha and beta Tensor alpha_tensor, beta_tensor; @@ -160,6 +163,16 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } } conv_op_desc->SetAttr("weight_scale", weight_scale); + } else if (is_weight_quantization) { + std::string scale_name = conv_weight_name + "_quant_scale"; + if (conv_op_desc->HasAttr(scale_name)) { + auto scale = conv_op_desc->GetAttr>(scale_name); + CHECK_EQ(scale.size(), alpha_tensor.numel()); + for (size_t i = 0; i < scale.size(); i++) { + scale[i] *= alpha_data[i]; + } + conv_op_desc->SetAttr(scale_name, scale); + } } else { // compute new conv_weight auto conv_weight_d = conv_weight_t->mutable_data(); diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc index fd9aadc5d01c2cb3b6c7a3e888503072a0798725..2021bdd3482663b823dd6c1dabdb11be5b5617e2 100644 --- a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc @@ -46,4 +46,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass, paddle::lite::mir::ConvElementwiseFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .ExcludeTargets({TARGET(kXPU), TARGET(kBM)}); diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc index af66f5ab66bd09907cb9d28f00f17d983e54c252..1c2297710b7cf41dc1adb7cde30d9fcfb61c79f0 100644 --- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -35,4 +35,7 @@ void ElementwiseAddActivationFusePass::Apply( REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, paddle::lite::mir::ElementwiseAddActivationFusePass) .BindTargets({TARGET(kAny)}) + .ExcludeTargets({TARGET(kXPU)}) + .ExcludeTargets({TARGET(kBM)}) + .ExcludeTargets({TARGET(kX86)}) .BindKernel("fusion_elementwise_add_activation"); diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc index ed10f06f5651f4000485279d682689101d80aa5a..46695be396596c2ce9b74bb771326171fc7b374b 100644 --- a/lite/core/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -23,8 +23,13 @@ namespace lite { namespace mir { void FcFusePass::Apply(const std::unique_ptr& graph) { - fusion::FcFuser fuser; +#ifdef LITE_WITH_X86 + fusion::FcFuser fuser(true); fuser(graph.get()); +#endif + + fusion::FcFuser fuser2(false); + fuser2(graph.get()); } } // namespace mir @@ -33,4 +38,7 @@ void FcFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass) .BindTargets({TARGET(kAny)}) + .ExcludeTargets({TARGET(kXPU)}) + .ExcludeTargets({TARGET(kBM)}) + .ExcludeTargets({TARGET(kCUDA)}) .BindKernel("fc"); diff --git a/lite/core/mir/fusion/fc_fuse_pass_test.cc b/lite/core/mir/fusion/fc_fuse_pass_test.cc index f7aa4bb5adcb848531ecc3a8f63bace1c2e3e0ff..54260732c5efe788f0d3740197253fa2321a7d02 100644 --- a/lite/core/mir/fusion/fc_fuse_pass_test.cc +++ b/lite/core/mir/fusion/fc_fuse_pass_test.cc @@ -88,6 +88,7 @@ USE_LITE_OP(mul); USE_LITE_OP(elementwise_add); USE_LITE_OP(elementwise_sub); USE_LITE_OP(fc); +USE_LITE_OP(relu); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc index 460c0fdf7a4309638b9852a315ca0efda02801ab..3c99131083d37ea2c8511ed136bff17c891529af 100644 
--- a/lite/core/mir/fusion/fc_fuser.cc +++ b/lite/core/mir/fusion/fc_fuser.cc @@ -35,12 +35,23 @@ void FcFuser::BuildPattern() { std::vector mul_inputs{W, x}; std::vector add_inputs{mul_out, b}; mul_inputs >> *mul >> *mul_out; - add_inputs >> *add >> *Out; // Some op specialities. mul_out->AsIntermediate(); mul->AsIntermediate(); add->AsIntermediate(); + + if (with_relu_) { + auto* add_out = VarNode("add_out"); + auto* relu = OpNode("relu", "relu"); + std::vector relu_inputs{add_out}; + add_inputs >> *add >> *add_out; + relu_inputs >> *relu >> *Out; + add_out->AsIntermediate(); + relu->AsIntermediate(); + } else { + add_inputs >> *add >> *Out; + } } void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { @@ -71,6 +82,9 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetAttr( "in_num_col_dims", matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + if (with_relu_) { + op_desc.SetAttr("activation_type", std::string{"relu"}); + } return op_desc; } diff --git a/lite/core/mir/fusion/fc_fuser.h b/lite/core/mir/fusion/fc_fuser.h index 7ba07527898c7e648c5f7f9151642ab0928fa496..6cb08f41574b67df1c78fa296d2d395771a66ee1 100644 --- a/lite/core/mir/fusion/fc_fuser.h +++ b/lite/core/mir/fusion/fc_fuser.h @@ -25,11 +25,13 @@ namespace fusion { class FcFuser : public FuseBase { public: + explicit FcFuser(bool with_relu) : with_relu_(with_relu) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + bool with_relu_; }; } // namespace fusion diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc index f823f45dc66f8ef6cc67cbb9b0d9860c86ec9340..da611e4490f4ba7268d9011b3dbb391a63a88305 100644 --- a/lite/core/mir/fusion/quant_dequant_op_fuser.cc +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -396,6 +396,8 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, op_desc->SetAttr("input_scale", scale_value); op_desc->SetInput("X", {input_act_node->arg()->name}); IR_NODE_LINK_TO(input_act_node, quantized_node) + auto update_op_desc = *quantized_node->stmt()->mutable_op_info(); + quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places()); // delete nodes and edges std::unordered_set nodes2rm = {input_scale_node, @@ -440,6 +442,8 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph, op_desc->SetInput("Y", {input_act_right_node->arg()->name}); IR_NODE_LINK_TO(input_act_left_node, quantized_node) IR_NODE_LINK_TO(input_act_right_node, quantized_node) + auto update_op_desc = *quantized_node->stmt()->mutable_op_info(); + quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places()); // delete nodes and edges std::unordered_set nodes2rm = {input_scale_left_node, diff --git a/lite/core/mir/fusion/sequence_pool_concat_fuse_pass.cc b/lite/core/mir/fusion/sequence_pool_concat_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c3b44ca12a1a9ad76720a9363533b9a20dd0999 --- /dev/null +++ b/lite/core/mir/fusion/sequence_pool_concat_fuse_pass.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h"
+#include
+#include
+#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::SequencePoolConcatFuser fuser;
+  fuser(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_sequence_pool_concat_fuse_pass,
+                  paddle::lite::mir::SequencePoolConcatFusePass)
+    .BindTargets({TARGET(kCUDA)});
diff --git a/lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h b/lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..38f931f502430211c1e51de5e9f81af9e43462c8
--- /dev/null
+++ b/lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class SequencePoolConcatFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/sequence_pool_concat_fuser.cc b/lite/core/mir/fusion/sequence_pool_concat_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d1c22aee86505a0d8e3f32b263cbbd9521504e6a
--- /dev/null
+++ b/lite/core/mir/fusion/sequence_pool_concat_fuser.cc
@@ -0,0 +1,153 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
+#include
+#include
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+// Merge {sequence_pool x 7, concat} into a single sequence_pool_concat op:
+//
+//   src1  src2 ... src7             src1  src2 ... src7
+//    |     |        |                 \     |      /
+//    v     v        v        =>        v    v     v
+//   sequence_pool x 7             sequence_pool_concat
+//    \     |        /                       |
+//     v    v       v                        v
+//        concat                            out
+void SequencePoolConcatFuser::BuildPattern() {
+  // create nodes.
+  auto* concat = OpNode("concat", "concat")->AsIntermediate();
+
+#define STR1(R) #R
+#define STR2(R) STR1(R)
+
+#define POOL_CONCAT_PATTERN(num)                                              \
+  auto* x_##num = VarNode(STR2(sequence_pool_x_##num))                        \
+                      ->assert_is_op_input("sequence_pool", "X")              \
+                      ->AsInput();                                            \
+  auto* sequence_pool_##num =                                                 \
+      OpNode(STR2(sequence_pool_##num), "sequence_pool")->AsIntermediate();   \
+  auto* sequence_pool_##num##_out =                                           \
+      VarNode(STR2(sequence_pool_##num##_out))                                \
+          ->assert_is_op_output("sequence_pool", "Out")                       \
+          ->assert_is_op_nth_input("concat", "X", num - 1)                    \
+          ->AsIntermediate();                                                 \
+  auto* sequence_pool_##num##_idx =                                           \
+      VarNode(STR2(sequence_pool_##num##_idx))                                \
+          ->assert_is_op_output("sequence_pool", "MaxIndex")                  \
+          ->AsIntermediate();                                                 \
+  *sequence_pool_##num >> *sequence_pool_##num##_idx;                         \
+  *x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
+
+  auto* concat_out =
+      VarNode("concat_out")->assert_is_op_output("concat", "Out");
+  *concat >> *concat_out;
+
+  POOL_CONCAT_PATTERN(1);
+  POOL_CONCAT_PATTERN(2);
+  POOL_CONCAT_PATTERN(3);
+  POOL_CONCAT_PATTERN(4);
+  POOL_CONCAT_PATTERN(5);
+  POOL_CONCAT_PATTERN(6);
+  POOL_CONCAT_PATTERN(7);
+
+#undef POOL_CONCAT_PATTERN
+#undef STR1
+#undef STR2
+}
+
+void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
+                                            const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto sequence_pool_concat_op =
+      LiteOpRegistry::Global().Create("sequence_pool_concat");
+
+  auto concat = matched.at("concat")->stmt()->op();
+  auto* scope = concat->scope();
+  auto& valid_places = concat->valid_places();
+  sequence_pool_concat_op->Attach(op_desc, scope);
+
+  auto* new_op_node =
+      graph->GraphCreateInstructNode(sequence_pool_concat_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_1"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_2"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_3"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_4"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_5"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_6"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("sequence_pool_x_7"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
+}
+
+cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
+  op_desc.SetType("sequence_pool_concat");
+  op_desc.SetInput("X",
+                   {matched.at("sequence_pool_x_1")->arg()->name,
+                    matched.at("sequence_pool_x_2")->arg()->name,
+                    matched.at("sequence_pool_x_3")->arg()->name,
+                    matched.at("sequence_pool_x_4")->arg()->name,
+                    matched.at("sequence_pool_x_5")->arg()->name,
+                    matched.at("sequence_pool_x_6")->arg()->name,
+                    matched.at("sequence_pool_x_7")->arg()->name});
+
+  std::vector<std::string> pooltypes;
+  pooltypes.push_back(matched.at("sequence_pool_1")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  pooltypes.push_back(matched.at("sequence_pool_2")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  pooltypes.push_back(matched.at("sequence_pool_3")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  pooltypes.push_back(matched.at("sequence_pool_4")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  pooltypes.push_back(matched.at("sequence_pool_5")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  pooltypes.push_back(matched.at("sequence_pool_6")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  pooltypes.push_back(matched.at("sequence_pool_7")
+                          ->stmt()
+                          ->op_info()
+                          ->GetAttr<std::string>("pooltype"));
+  op_desc.SetAttr("pooltype", pooltypes);
+
+  op_desc.SetOutput("Out", {matched.at("concat_out")->arg()->name});
+
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/sequence_pool_concat_fuser.h b/lite/core/mir/fusion/sequence_pool_concat_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..b8f731becd4a19554ddc347db7cca4bb6fd66ee9
--- /dev/null
+++ b/lite/core/mir/fusion/sequence_pool_concat_fuser.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class SequencePoolConcatFuser : public FuseBase {
+ public:
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ce2248cbc23d8887a22f94c14b2507fb0cacbed
--- /dev/null
+++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void VarConv2dActivationFusePass::Apply( + const std::unique_ptr& graph) { + std::vector act_types{"relu"}; + for (auto act_type : act_types) { + fusion::VarConvActivationFuser fuser(act_type, "var_conv_2d"); + fuser(graph.get()); + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_var_conv_2d_activation_fuse_pass, + paddle::lite::mir::VarConv2dActivationFusePass) + .BindTargets({TARGET(kCUDA)}); diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..7616aadef340d3e4d6bc11534dd839c91fe9ed1d --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class VarConv2dActivationFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..eabd97ae4513b84c9c002aa1587d45cce6b22e21 --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void VarConvActivationFuser::BuildPattern() { + // create nodes. 
+ auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput(); + auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput(); + + auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate(); + + auto* act = OpNode("act", act_type_)->AsIntermediate(); + + auto* conv2d_out = VarNode("conv2d_out") + ->assert_is_op_output(conv_type_, "Out") + ->assert_is_op_input(act_type_, "X") + ->AsIntermediate(); + auto* conv2d_out_1 = VarNode("conv2d_out_1") + ->assert_is_op_output(conv_type_, "Col") + ->AsIntermediate(); + + auto* out = + VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput(); + + // create topology. + std::vector conv2d_inputs{filter, input}; + conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out; + *conv2d >> *conv2d_out_1; +} + +void VarConvActivationFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto conv_op = LiteOpRegistry::Global().Create(conv_type_); + auto conv_old = matched.at("var_conv_2d")->stmt()->op(); + auto* scope = conv_old->scope(); + auto& valid_places = conv_old->valid_places(); + conv_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places); + + IR_NODE_LINK_TO(matched.at("X"), new_op_node); + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc VarConvActivationFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc = *matched.at("var_conv_2d")->stmt()->op_info(); + op_desc.SetOutput("Out", {matched.at("output")->arg()->name}); + cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info(); + + if (act_type_ == "relu") { + op_desc.SetAttr("fuse_relu", true); + } + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/var_conv_2d_activation_fuser.h b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..68bc89f7d13d38dc07814f3296a25bfd7dea0248 --- /dev/null +++ b/lite/core/mir/fusion/var_conv_2d_activation_fuser.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
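// [Editor's note] The fuser declared below mirrors ConvActivationFuser: for a
// matched var_conv_2d + relu pair, GenOpDesc() above sets the 'fuse_relu'
// attribute on the var_conv_2d op desc and rewires its "Out" to the
// activation's output, so the fused kernel is expected to apply ReLU itself.
// The "Col" output of var_conv_2d is marked intermediate and dropped from the
// fused graph.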
+ +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class VarConvActivationFuser : public FuseBase { + public: + explicit VarConvActivationFuser(const std::string& act_type, + const std::string& conv_type) + : act_type_(act_type), conv_type_(conv_type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string act_type_; + std::string conv_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index 76ea9555c29a245aa9f20b158f0706557940bef8..3a27360f94d7d828e1c19214d621f1dfe4e048ca 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -36,15 +36,6 @@ std::string Visualize(mir::SSAGraph* graph) { int id = 0; std::set exists_args; - std::map graph_col; // Different colors of subgraphs - graph_col.insert({{1, "red"}, - {2, "green"}, - {3, "cyan"}, - {4, "bisque3"}, - {5, "coral"}, - {6, "darkseagreen1"}, - {7, "goldenrod1"}, - {8, "darkorchid"}}); for (auto& node : graph->mutable_nodes()) { std::string key; if (node.IsArg()) { @@ -52,24 +43,12 @@ std::string Visualize(mir::SSAGraph* graph) { } else { key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++); } - if (node.IsStmt()) { - auto& stmt = node.AsStmt(); - auto sub_id = stmt.subgraph_id(); - auto it = graph_col.find(sub_id); - if (sub_id > 0 && it != graph_col.end()) { - dot.AddNode(key, - {Dot::Attr("shape", "box"), - Dot::Attr("style", "filled"), - Dot::Attr("color", "black"), - Dot::Attr("fillcolor", it->second)}); - } else { - dot.AddNode(key, - {Dot::Attr("shape", "box"), - Dot::Attr("style", "filled"), - Dot::Attr("color", "black"), - Dot::Attr("fillcolor", "yellow")}); - } + dot.AddNode(key, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", "yellow")}); for (auto& x : node.inlinks) { auto name = x->AsArg().name; if (!exists_args.count(name)) { diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 1f2355e8a3205cce3410bd2cb6ac4a17d8fde602..6256a49a99b9097664c192d40502daf506437a31 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -50,7 +50,7 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( "lod_reset", "concat", "yolo_box", - "graph_op", + "subgraph", "feed", "fetch"}; for (auto* tmp : node->inlinks) { @@ -255,4 +255,5 @@ void MemoryOptimizePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass) - .BindTargets({TARGET(kARM)}); + .BindTargets({TARGET(kARM)}) + .ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU), TARGET(kBM)}); diff --git a/lite/core/mir/node.cc b/lite/core/mir/node.cc index 4a90e530a46c4d42d2ba032da1828973dfc1bcef..52fd39182a7132777231929d49c319bb961cf7f9 100644 --- a/lite/core/mir/node.cc +++ b/lite/core/mir/node.cc @@ -53,6 +53,11 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, } valid_kernels_ = op_->CreateKernels(valid_places); } +void mir::Node::Stmt::ResetKernels(const std::vector &valid_places) { + CHECK(op_) << "change valid place failed, not created op"; + valid_kernels_.clear(); + 
+  valid_kernels_ = op_->CreateKernels(valid_places);
+}
 
 mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
   auto &x = AsArg();
diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h
index 60fa1fb1ebe49e1be38a7d84cb82545389ea4aac..e7c44d2be689a9d890158c097e198314413d1ba3 100644
--- a/lite/core/mir/node.h
+++ b/lite/core/mir/node.h
@@ -53,6 +53,7 @@ class Node {
                  const std::vector<Place>& valid_places,
                  lite::Scope* scope = nullptr);
 
+    void ResetKernels(const std::vector<Place>& valid_places);
     std::string op_type() const { return op_info()->Type(); }
     const OpInfo* op_info() const;
     OpInfo* mutable_op_info();
@@ -64,9 +65,6 @@ class Node {
       return valid_kernels_;
     }
 
-    void ClearSubgraphID() { subgraph_id_ = -1 /* note: not 0 */; }
-    void SetSubgraphID(int id) { subgraph_id_ = id; }
-    int subgraph_id() const { return subgraph_id_; }
     void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
     const std::shared_ptr<OpLite> op() const { return op_; }
@@ -82,11 +80,6 @@ class Node {
 
     // Description.
     std::string desc;
-
-   protected:
-    // -1 means not in subgraph, 0 means supported but not one id, id started
-    // from 1
-    int subgraph_id_{-1};
   };
 
   struct Arg {
diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h
index 4de0fdbf357160348a403d3c8527fe62891237f0..4e8c8be292bbd5e7f46664378634d4f1aeed2965 100644
--- a/lite/core/mir/pass.h
+++ b/lite/core/mir/pass.h
@@ -52,34 +52,44 @@ class Pass {
 
   // Bind targets. At runtime, there must be one device in the bound targets.
   void BindTargets(const std::set<TargetType>& targets) {
-    std::set<TargetType> res;
     for (const auto& target : targets) {
       const std::set<TargetType>& universe = ExpandValidTargets(target);
       std::set_union(bound_targets_.begin(),
                      bound_targets_.end(),
                      universe.begin(),
                      universe.end(),
-                     std::inserter(res, res.begin()));
+                     std::inserter(bound_targets_, bound_targets_.begin()));
     }
-    bound_targets_ = res;
   }
 
   // Exclude targets. At runtime, there must be one device in the bound targets.
+  // Disable the pass if one of the valid devices is in the excluded targets.
   void ExcludeTargets(const std::set<TargetType>& targets) {
-    std::set<TargetType> res;
     for (const auto& target : targets) {
       const std::set<TargetType>& universe = ExpandValidTargets(target);
-      std::set_difference(bound_targets_.begin(),
-                          bound_targets_.end(),
-                          universe.begin(),
-                          universe.end(),
-                          std::inserter(res, res.begin()));
+      std::set<TargetType> updated_bound_targets;
+      std::set_difference(
+          bound_targets_.begin(),
+          bound_targets_.end(),
+          universe.begin(),
+          universe.end(),
+          std::inserter(updated_bound_targets, updated_bound_targets.begin()));
+      bound_targets_ = updated_bound_targets;
+      std::set_union(
+          excluded_targets_.begin(),
+          excluded_targets_.end(),
+          universe.begin(),
+          universe.end(),
+          std::inserter(excluded_targets_, excluded_targets_.begin()));
     }
-    bound_targets_ = res;
   }
 
   // Get all bound targets.
-  const std::set<TargetType>& Targets() const { return bound_targets_; }
+  const std::set<TargetType>& BoundTargets() const { return bound_targets_; }
+  // Get all excluded targets.
+  const std::set<TargetType>& ExcludedTargets() const {
+    return excluded_targets_;
+  }
 
   // Some passes are only available on qualified kernels and need to be
   // explicitly declared.
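A standalone sketch of the set bookkeeping above (illustrative, not part of the patch; the Target enum and the one-element universe stand in for TargetType and ExpandValidTargets): ExcludeTargets subtracts each excluded target's expanded universe from bound_targets_ and accumulates it into excluded_targets_, which PassMatchesTarget later checks for overlap.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>

enum class Target { kARM, kOpenCL, kNPU };

int main() {
  std::set<Target> bound{Target::kARM, Target::kOpenCL, Target::kNPU};
  std::set<Target> excluded;
  // ExcludeTargets({kOpenCL}): drop its universe from the bound set...
  const std::set<Target> universe{Target::kOpenCL};
  std::set<Target> updated;
  std::set_difference(bound.begin(), bound.end(),
                      universe.begin(), universe.end(),
                      std::inserter(updated, updated.begin()));
  bound = updated;
  // ...and accumulate it into the excluded set.
  std::set_union(excluded.begin(), excluded.end(),
                 universe.begin(), universe.end(),
                 std::inserter(excluded, excluded.begin()));
  std::cout << bound.size() << " bound, " << excluded.size() << " excluded\n";
  // Prints "2 bound, 1 excluded": the pass is now disabled for OpenCL runs.
  return 0;
}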
@@ -116,6 +126,7 @@ class Pass {
   std::string name_;
   std::string doc_;
   std::set<TargetType> bound_targets_;
+  std::set<TargetType> excluded_targets_;
   std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
 };
 
diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc
index 4f6be2c186d2d940a799201812cce397a9e94eb4..5bddfcbd3c17288546dc6e0a0b4ebf984d26c504 100644
--- a/lite/core/mir/pass_utils.cc
+++ b/lite/core/mir/pass_utils.cc
@@ -47,10 +47,34 @@ bool KernelRegistered(const std::string name, const Place& place) {
   return false;
 }
 
-bool PassMatchesTarget(const mir::Pass& pass, TargetType target) {
-  const auto& targets = pass.Targets();
-  if (targets.find(TARGET(kAny)) != targets.end()) return true;
-  return (targets.find(target) != targets.end());
+bool PassMatchesTarget(const mir::Pass& pass,
+                       const std::set<TargetType>& targets) {
+  // Is the pass suitable for the given targets? The condition is that the
+  // intersection of targets and the pass's bound targets is not empty, and
+  // the intersection of targets and the pass's excluded targets is empty.
+  // As a formula:
+  //   matched = !empty(targets ^ pass.bound_targets) &&
+  //             empty(targets ^ pass.excluded_targets)
+  // where ^ is the set-intersection operation.
+  const auto& bound_targets = pass.BoundTargets();
+  bool matched = bound_targets.find(TARGET(kAny)) != bound_targets.end();
+  std::set<TargetType> inter_bound_targets;
+  std::set_intersection(
+      bound_targets.begin(),
+      bound_targets.end(),
+      targets.begin(),
+      targets.end(),
+      std::inserter(inter_bound_targets, inter_bound_targets.begin()));
+  matched |= !inter_bound_targets.empty();
+  const auto& excluded_targets = pass.ExcludedTargets();
+  matched &= excluded_targets.find(TARGET(kAny)) == excluded_targets.end();
+  std::set<TargetType> inter_excluded_targets;
+  std::set_intersection(
+      excluded_targets.begin(),
+      excluded_targets.end(),
+      targets.begin(),
+      targets.end(),
+      std::inserter(inter_excluded_targets, inter_excluded_targets.begin()));
+  matched &= inter_excluded_targets.empty();
+  return matched;
 }
 
 bool PassMatchesKernels(const mir::Pass& pass) {
diff --git a/lite/core/mir/pass_utils.h b/lite/core/mir/pass_utils.h
index 942f64bf3190be1f399ac6f014be0881b1450d9b..57e8da5e461f40bd79ece8139c3290e17e762996 100644
--- a/lite/core/mir/pass_utils.h
+++ b/lite/core/mir/pass_utils.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <set>
 #include <string>
 #include "lite/core/mir/pass.h"
 
@@ -24,7 +25,8 @@ namespace lite {
 bool KernelRegistered(const std::string name, const Place& place);
 
 // Check if the pass hits the hardware target.
-bool PassMatchesTarget(const mir::Pass& pass, TargetType target);
+bool PassMatchesTarget(const mir::Pass& pass,
+                       const std::set<TargetType>& targets);
 
 // Check if the pass hits all necessary operators.
bool PassMatchesKernels(const mir::Pass& pass); diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index 8e0fc55be2389244ae065b4c2809bbdd74be370c..b625919cbfb6d26ecbbd1bad36772aff86bee087 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -377,6 +377,19 @@ PMNode *PMNode::assert_is_op(const std::string &op_type) { return this; } +PMNode *PMNode::assert_is_not_op_type(const std::string &op_type) { + asserts_.emplace_back([op_type](const Node *x) { + if (x && x->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + if (op_info->Type() == op_type) { + return false; + } + } + return true; + }); + return this; +} + PMNode *PMNode::assert_is_var() { asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); return this; diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 47a0a30b5667ddc97b3783ab9edbab04281528a4..90c4359c6d3ade98cf60b5c23411e2026cdeccc9 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -123,6 +123,7 @@ struct PMNode { // Assertions, helper functions to simplify the pattern definition. PMNode* assert_is_op(); PMNode* assert_is_op(const std::string& op_type); + PMNode* assert_is_not_op_type(const std::string& op_type); PMNode* assert_is_var(); PMNode* assert_var_not_persistable(); PMNode* assert_is_persistable_var(); diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc index 8f22022789046900c3c09cfb122c914968d8d87f..2b5b65ce5903ede41137311c585c0e87eaaa0e9d 100644 --- a/lite/core/mir/ssa_graph.cc +++ b/lite/core/mir/ssa_graph.cc @@ -123,6 +123,9 @@ void SSAGraph::Build(const Program &program, return true; }; + std::unordered_map var_types = + program.var_data_type(); + std::unordered_map arg_update_node_map_; for (auto &op : program.ops()) { VLOG(3) << op->op_info()->Type(); @@ -137,6 +140,10 @@ void SSAGraph::Build(const Program &program, arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; } + if (var_types.count(name) && !arg_node->arg()->type) { + arg_node->arg()->type = LiteType::GetTensorTy( + TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + } if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); DirectedLink(arg_node, op_node); @@ -146,6 +153,10 @@ void SSAGraph::Build(const Program &program, auto *arg_node = &node_storage_.back(); arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; + if (var_types.count(name) && !arg_node->arg()->type) { + arg_node->arg()->type = LiteType::GetTensorTy( + TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + } if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc index 90aca56aec426f6b7ca0d300ded979ae7b10f6df..1cc8942d611db389a44cbf6a244775a5b666b587 100644 --- a/lite/core/mir/static_kernel_pick_pass.cc +++ b/lite/core/mir/static_kernel_pick_pass.cc @@ -14,7 +14,10 @@ #include "lite/core/mir/static_kernel_pick_pass.h" #include +#include #include +#include +#include #include #include #include "lite/core/mir/graph_visualize_pass.h" @@ -43,13 +46,33 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); + std::unordered_map in_types; + std::unordered_map out_types; + for (std::list::iterator i = node.inlinks.begin(); + i != node.inlinks.end(); + ++i) { + if ((*i)->arg()->type) + 
in_types[(*i)->arg()->name] = (*i)->arg()->type->precision(); + } + for (std::list::iterator i = node.outlinks.begin(); + i != node.outlinks.end(); + ++i) { + if ((*i)->arg()->type) + out_types[(*i)->arg()->name] = (*i)->arg()->type->precision(); + } // Get candidate kernels std::vector>> scored; CHECK(!instruct.kernels().empty()) << "No kernels found for " << instruct.op_type(); VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); for (auto&& kernel : instruct.kernels()) { - float score = KernelGrade(*kernel, graph->valid_places()); + float score = KernelGrade(instruct, + *kernel, + graph->valid_places(), + in_types, + out_types, + instruct.op_info()->input_names(), + instruct.op_info()->output_names()); VLOG(4) << "kernel->summary():" << kernel->summary() << " score:" << score; scored.emplace_back(score, std::move(kernel)); @@ -99,7 +122,13 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { instruct.ResetOp(update_desc, graph->valid_places()); scored.clear(); for (auto&& kernel : instruct.kernels()) { - float score = KernelGrade(*kernel, graph->valid_places()); + float score = KernelGrade(instruct, + *kernel, + graph->valid_places(), + in_types, + out_types, + instruct.op_info()->input_names(), + instruct.op_info()->output_names()); scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h index 90be0ea54e8761e2e68b12a396dde0df1bba3f26..f655b298bf2d800f4adf142ad14b8ac05ca00482 100644 --- a/lite/core/mir/static_kernel_pick_pass.h +++ b/lite/core/mir/static_kernel_pick_pass.h @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include "lite/core/mir/pass.h" #include "lite/core/types.h" @@ -48,8 +50,14 @@ class StaticKernelPickPass : public mir::StmtPass { private: // Score the kernel. - size_t KernelGrade(const lite::KernelBase& kernel, - const std::vector& places) { + size_t KernelGrade( + const lite::mir::Node::Stmt& instruct, + const lite::KernelBase& kernel, + const std::vector& places, + const std::unordered_map& in_types, + const std::unordered_map& out_types, + const std::vector& in_names, + const std::vector& out_names) { CHECK_GT(places.size(), 0) << "valid_places is empty."; float final_score{-1.}; Place winner_place{places[0]}; @@ -66,7 +74,7 @@ class StaticKernelPickPass : public mir::StmtPass { // valid_places.size() as default. // where i is the place's index in valid_places array. 
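      // e.g. with valid_places = {kARM, kOpenCL}, place_size is 2, so a
      // kARM candidate is weighted by 1.0 and a kOpenCL candidate by 0.5.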
      // score: score is the weighted sum of target, precision, and layout
-      for (int i = 0; i < place_size; ++i) {
+      for (size_t i = 0; i < place_size; ++i) {
         const auto& place = places[i];
         float weight = static_cast<float>(place_size - i) / place_size;
         size_t score{};
@@ -83,8 +91,12 @@ class StaticKernelPickPass : public mir::StmtPass {
           (place.precision == kernel.precision() ||
            kernel.precision() == PRECISION(kAny) ||
            place.precision == PRECISION(kAny))) {
-        score += kMax / static_cast<size_t>(
-                     core::KernelPickFactor::Factor::PrecisionFirst);
+        // Skip the precision score if the kernel is int8 but the op is not
+        // an int8 (quantized) op.
+        if (!(kernel.precision() == PRECISION(kInt8) &&
+              !instruct.op_info()->HasAttr("enable_int8"))) {
+          score += kMax / static_cast<size_t>(
+                       core::KernelPickFactor::Factor::PrecisionFirst);
+        }
       }
       VLOG(4) << "[score s2]:" << score;
       if (kernel_pick_factors_.IsDataLayoutConsidered() &&
@@ -95,6 +107,37 @@
             core::KernelPickFactor::Factor::DataLayoutFirst);
       }
       VLOG(4) << "[score s3]:" << score;
+
+      // Add a new rule for precision: when the input types are consistent
+      // with the kernel's declared input types and the output types are
+      // consistent with the kernel's declared output types, prefer the kernel
+      // of that precision. Note that this strategy is not compatible with
+      // quantization, so quantized (int8) ops are skipped.
+      if (!instruct.op_info()->HasAttr("enable_int8")) {
+        bool type_match = true;
+        for (size_t i = 0; i < in_names.size(); ++i) {
+          std::string tmp;
+          CHECK(instruct.op_info()->GetInputArgname(in_names[i], &tmp));
+          if (in_types.count(in_names[i]) &&
+              in_types.at(in_names[i]) !=
+                  kernel.GetInputDeclType(tmp)->precision()) {
+            type_match = false;
+          }
+        }
+        for (size_t i = 0; i < out_names.size(); ++i) {
+          std::string tmp;
+          CHECK(instruct.op_info()->GetOutputArgname(out_names[i], &tmp));
+          if (out_types.count(out_names[i]) &&
+              out_types.at(out_names[i]) !=
+                  kernel.GetOutputDeclType(tmp)->precision()) {
+            type_match = false;
+          }
+        }
+        if (type_match) {
+          score *= 2;
+        }
+        VLOG(4) << "[score s4]:" << score;
+      }
+
       if (weight * score > final_score) {
         final_score = weight * score;
         winner_place = place;
diff --git a/lite/core/mir/subgraph/CMakeLists.txt b/lite/core/mir/subgraph/CMakeLists.txt
index 95b5fe5ae13e03940bda8d83fcfc252b4ca490ab..f8aa09676c2d1e6d4df6fafbaf6a54bc69491acc 100644
--- a/lite/core/mir/subgraph/CMakeLists.txt
+++ b/lite/core/mir/subgraph/CMakeLists.txt
@@ -1,50 +1,30 @@
-
+lite_cc_library(subgraph_detector
+  SRCS subgraph_detector.cc
+  DEPS mir_pass types subgraph_op)
 lite_cc_library(subgraph_pass
-  SRCS subgraph_program_pass.cc
-  DEPS mir_pass types ${mir_fusers})
-lite_cc_test(test_subgraph_pass SRCS subgraph_program_pass_test.cc
-  DEPS subgraph_pass mir_passes gflags model_parser cxx_api
-  ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
-if (WITH_TESTING)
-  add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v1_tar_gz)
-  add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
-  set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
-  set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
-endif()
-
-set(subgraph_passes subgraph_pass)
-
-if(LITE_WITH_NPU)
-  lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc
-    DEPS mir_pass types context ${mir_fusers} ${npu_bridges} graph_op subgraph_pass)
-  list(APPEND subgraph_passes npu_pass)
-  lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc
-    DEPS npu_pass mir_passes paddle_api_full paddle_api_light gflags
-    ARGS
--model_dir=${LITE_MODEL_DIR}/mobilenet_v1 - --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL) - if (WITH_TESTING) - add_dependencies(test_npu_pass extern_lite_download_mobilenet_v1_tar_gz) - add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) + SRCS subgraph_pass.cc + DEPS mir_pass types context ${mir_fusers} subgraph_detector) +if (WITH_TESTING AND NOT LITE_WITH_CUDA) + lite_cc_test(test_subgraph_detector + SRCS subgraph_detector_test.cc + DEPS subgraph_detector mir_passes gflags model_parser cxx_api + ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL) + add_dependencies(test_subgraph_detector + extern_lite_download_mobilenet_v1_tar_gz + extern_lite_download_mobilenet_v2_relu_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") - set_target_properties(test_npu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - endif() -endif() - -if(LITE_WITH_XPU) - lite_cc_library(xpu_pass SRCS generate_xpu_program_pass.cc - DEPS mir_pass types context ${mir_fusers} ${xpu_bridges} ${xpu_builder_libs} graph_op subgraph_pass) - list(APPEND subgraph_passes xpu_pass) - lite_cc_test(test_xpu_pass SRCS generate_xpu_program_pass_test.cc - DEPS xpu_pass mir_passes paddle_api_full gflags - ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 - --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL) - if (WITH_TESTING) - add_dependencies(test_xpu_pass extern_lite_download_mobilenet_v1_tar_gz) - add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz) + set_target_properties(test_subgraph_detector PROPERTIES LINK_FLAGS "${LINK_FLAGS}") + lite_cc_test(test_subgraph_pass + SRCS subgraph_pass_test.cc + DEPS mir_passes paddle_api_full paddle_api_light gflags + ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 + --optimized_model_dir=${LITE_MODEL_DIR}/lite_model_opt SERIAL) + add_dependencies(test_subgraph_pass + extern_lite_download_mobilenet_v1_tar_gz + extern_lite_download_mobilenet_v2_relu_tar_gz) set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map") - set_target_properties(test_xpu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") - endif() + set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}") endif() -set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes") -message(STATUS "----> subgraph_passes: ${subgraph_passes}") +set(mir_subgraphs subgraph_pass CACHE INTERNAL "mir_subgraphs") +message(STATUS "----> mir_subgraphs: ${mir_subgraphs}") diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc deleted file mode 100644 index c83cd70d8225a0b33a50ebdad331283f377e0059..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/core/mir/subgraph/generate_npu_program_pass.h" -#include -#include -#include -#include -#include -#include "lite/core/mir/graph_visualize_pass.h" -#include "lite/core/mir/pass_registry.h" -#include "lite/core/mir/pattern_matcher.h" - -#include "lite/backends/npu/builder.h" -#include "lite/kernels/npu/bridges/paddle_use_npu_bridges.h" -#include "lite/kernels/npu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -std::shared_ptr GenerateNPUProgramPass::CvtVarNode( - lite::mir::Node* var_node, const Scope* scope) { - CHECK(var_node->IsArg()); - const auto& arg = var_node->AsArg(); - VLOG(4) << "[NPU] Convert var node " << arg.name; - - auto* var = scope->FindVar(arg.name); - CHECK(var); - auto* tensor = var->GetMutable(); - CHECK(tensor); - auto dims = tensor->dims(); - if (arg.is_weight) { - auto wgt = std::make_shared(arg.name); - LOG(INFO) << "[NPU] Convert const var node " << arg.name; - VLOG(4) << dims; - wgt->set_attr_value(lite::npu::CvtTensor(tensor)); - return wgt; - } else { - CHECK_EQ(dims.size(), 4); - LOG(INFO) << "[NPU] Convert data var node " << arg.name; - LOG(INFO) << dims; - // TODO(xxx): support more types and dims size - ge::TensorDesc desc(ge::Shape(dims.Vectorize()), - ge::Format::FORMAT_NCHW, - ge::DataType::DT_FLOAT); - - // auto size = desc.GetShape().GetShapeSize(); - // ge::TensorUtils::SetSize(desc, size*sizeof(float)); - // ge::TensorUtils::SetRealDimCnt(desc, 4); - auto data = std::make_shared(arg.name); - data->update_input_desc_x(desc); - return data; - } - return nullptr; -} - -void GenerateNPUProgramPass::CvtAllOpNodes( - const std::vector& nodes2cvt, - lite::kernels::npu::bridges::node_map_type* converted_vars) { - const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); - const auto& cvtfunc_map = bridges.AllFunctions(); - // return record all converted vars - // op node's inputs must be found in converted_vars - for (auto& node : nodes2cvt) { - lite::kernels::npu::bridges::node_map_type node_inputs; - auto& stmt = node->AsStmt(); - for (auto& var_node : node->inlinks) { - auto& arg = var_node->AsArg(); - // weight should be handled in the converter, so skip here - if (arg.is_weight) { - continue; - } - auto var_name = arg.name; - if (!converted_vars->count(var_name)) { - converted_vars->insert( - std::make_pair(var_name, CvtVarNode(var_node, stmt.op()->scope()))); - } - node_inputs.insert(*converted_vars->find(var_name)); - } - auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs); - converted_vars->insert(node_outputs.begin(), node_outputs.end()); - } -} - -std::string GenerateNPUProgramPass::BuildNPUGraph( - const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id) { - auto ordered_nodes = GetTopologicalOrder(op_nodes); - lite::kernels::npu::bridges::node_map_type converted_vars; - CvtAllOpNodes(ordered_nodes, &converted_vars); - - std::vector in_var_names; - std::vector out_var_names; - std::vector inputs; - std::vector outputs; - for (auto i : in_data_vars) { - auto argname = i->AsArg().name; - in_var_names.push_back(argname); - inputs.push_back(*converted_vars.at(argname)); - } - for (auto i : out_data_vars) { - auto argname = i->AsArg().name; - out_var_names.push_back(argname); - outputs.push_back(*converted_vars.at(argname)); - } - - std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights"; - auto any_op = (*op_nodes.begin())->AsStmt().op(); - auto 
weight = any_op->scope()->Var(weight_var_name)->GetMutable(); - weight->set_persistable(true); - weight->set_precision(PRECISION(kInt8)); - // Compiling IR graph to NPU model and store mode data into weight tensor with - // persistable=true, Sothat the model parser can recognize it and save it to - // param files - if (!lite::npu::BuildModel(inputs, outputs, weight)) { - LOG(WARNING) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")"; - throw std::runtime_error("Build NPU graph failed."); - } - LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")"; - return weight_var_name; -} - -void GenerateNPUProgramPass::GenNPUSubgraph( - const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id) { - std::unordered_set in_data_vars; - std::unordered_set in_wgt_vars; - std::unordered_set out_data_vars; - std::unordered_set out_unused_vars; - FindInputOutputVars( - op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars); - - auto weight_var_name = - BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id); - - auto any_op = (*op_nodes.begin())->AsStmt().op(); - InsertNewNode(graph, - weight_var_name, - any_op->scope(), - any_op->valid_places(), - in_data_vars, - in_wgt_vars, - out_data_vars, - out_unused_vars); - - auto nodes2rm = GetNode2rm( - op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars}); - - GraphSafeRemoveNodes(graph.get(), nodes2rm); -} - -void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { - LOG(INFO) << "[NPU] Before NPU Pass \n" << Visualize(graph.get()); - const auto& bridges = lite::kernels::npu::bridges::Factory::Instance(); - const auto& op_map = bridges.AllFunctions(); - std::vector supported_op_types; - for (auto& i : op_map) { - LOG(INFO) << "[NPU] Supported type: " << i.first; - supported_op_types.push_back(i.first); - } - - try { - int num_subgraph = FuseSubgraph(graph, supported_op_types); - InferOnce(graph); - auto op_nodes_all = ClassifySubgraph(graph); - CHECK_EQ(op_nodes_all.size(), num_subgraph); - int id = 1; - for (auto& op_nodes : op_nodes_all) { - LOG(INFO) << "[NPU] Converting Subgraph " << id; - GenNPUSubgraph(graph, op_nodes.second, id); - LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n" - << Visualize(graph.get()); - id++; - } - } catch (...) { - LOG(WARNING) << "[NPU] Build NPU graph failed."; - throw std::runtime_error("[NPU] Build NPU graph failed."); - } - - for (auto& item : graph->StmtTopologicalOrder()) { - if (item->IsStmt()) { - auto& stmt = item->AsStmt(); - LOG(INFO) << stmt; - insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); - } - } -} - -std::unique_ptr GenerateNPUProgramPass::GenProgram() { - LOG(INFO) << "[NPU] program insts.size " << insts_.size(); - std::unique_ptr program( - new RuntimeProgram(std::move(insts_))); - return program; -} - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle - -REGISTER_MIR_PASS(generate_npu_program_pass, - paddle::lite::mir::subgraph::GenerateNPUProgramPass) - .BindTargets({TARGET(kNPU)}); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h deleted file mode 100644 index 823ca5f1f624a9e920a5f395a9d5098c5ea52929..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_npu_program_pass.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "lite/backends/npu/builder.h" -#include "lite/core/mir/pass.h" -#include "lite/core/mir/subgraph/subgraph_program_pass.h" -#include "lite/kernels/npu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -class GenerateNPUProgramPass : public SubgraphProgramPass { - public: - using key2nodes_t = std::map; - - void Apply(const std::unique_ptr& graph) override; - std::unique_ptr GenProgram(); - - protected: - // nodes2cvt: op nodes to convert - // return cvted_vars: converted var nodes - void CvtAllOpNodes(const std::vector& nodes2cvt, - lite::kernels::npu::bridges::node_map_type* cvted_vars); - - std::shared_ptr CvtVarNode(lite::mir::Node* var_node, - const Scope* scope); - - std::string BuildNPUGraph(const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id); - - void GenNPUSubgraph(const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id); - - private: - std::vector insts_; -}; - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc deleted file mode 100644 index 95339d6175c98f22d542db24f02d6d714ccbe2a8..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include "lite/api/paddle_api.h" -#include "lite/api/paddle_use_kernels.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/api/paddle_use_passes.h" -#include "lite/api/test_helper.h" -#include "lite/utils/cp_logging.h" - -DEFINE_string(model_file, "", "model file path of combined protobuf model"); -DEFINE_string(params_file, "", "params file path of combined protobuf model"); -DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model"); -DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors"); -DEFINE_int32(output_tensor_num, 1, "number of output tensors"); - -namespace paddle { -namespace lite { - -std::vector> ParseShape(std::string txt) { - std::vector> shape; - while (!txt.empty()) { - size_t idx = txt.find_first_of(":"); - std::string dims = txt.substr(0, idx); - std::vector s; - while (!dims.empty()) { - size_t idx = dims.find_first_of(","); - int d = atoi(dims.substr(0, idx).c_str()); - VLOG(3) << d; - s.push_back(d); - if (idx == std::string::npos) { - break; - } else { - dims = dims.substr(idx + 1); - } - } - shape.push_back(s); - if (idx == std::string::npos) { - break; - } else { - txt = txt.substr(idx + 1); - } - } - return shape; -} - -int64_t ShapeProduction(std::vector shape) { - int64_t s = 1; - for (int64_t dim : shape) { - s *= dim; - } - return s; -} - -void FillInputTensor( - const std::shared_ptr& predictor, - const std::vector>& input_tensor_shape, - const float value) { - for (int i = 0; i < input_tensor_shape.size(); i++) { - auto input_tensor = predictor->GetInput(i); - input_tensor->Resize(input_tensor_shape[i]); - auto input_tensor_data = input_tensor->mutable_data(); - auto input_tensor_size = ShapeProduction(input_tensor->shape()); - for (int j = 0; j < input_tensor_size; j++) { - input_tensor_data[i] = value; - } - } -} - -void CompareOutputTensor( - const std::shared_ptr& tar_predictor, - const std::shared_ptr& ref_predictor, - const int output_tensor_num) { - for (int i = 0; i < output_tensor_num; i++) { - auto tar_output_tensor = tar_predictor->GetOutput(i); - auto ref_output_tensor = ref_predictor->GetOutput(i); - auto tar_output_tensor_data = tar_output_tensor->data(); - auto ref_output_tensor_data = ref_output_tensor->data(); - auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape()); - auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape()); - EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size); - for (size_t j = 0; j < ref_output_tensor_size; j++) { - auto abs_diff = - std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]); - auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6); - VLOG(3) << "val: " << tar_output_tensor_data[j] - << " ref: " << ref_output_tensor_data[j] - << " abs_diff: " << abs_diff << " rel_diff: " << rel_diff; - EXPECT_LT(rel_diff, 0.1); - } - } -} - -std::shared_ptr TestModel( - const std::string& model_dir, - const std::string& model_file, - const std::string& params_file, - const std::vector& valid_places, - const std::vector>& input_tensor_shape, - const std::string& optimized_model_dir) { - // generate optimized model - lite_api::CxxConfig cxx_config; - cxx_config.set_model_dir(model_dir); - cxx_config.set_model_file(model_file); - cxx_config.set_param_file(params_file); - cxx_config.set_valid_places(valid_places); - auto predictor = lite_api::CreatePaddlePredictor(cxx_config); - FillInputTensor(predictor, input_tensor_shape, 1); - predictor->SaveOptimizedModel(optimized_model_dir, - 
lite_api::LiteModelType::kNaiveBuffer); - // load optimized model - lite_api::MobileConfig mobile_config; - mobile_config.set_model_dir(optimized_model_dir); - mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); - mobile_config.set_threads(1); - predictor = lite_api::CreatePaddlePredictor(mobile_config); - FillInputTensor(predictor, input_tensor_shape, 1); - // run optimized model - for (int i = 0; i < FLAGS_warmup; i++) { - predictor->Run(); - } - for (int i = 0; i < FLAGS_repeats; i++) { - auto start = GetCurrentUS(); - predictor->Run(); - LOG(INFO) << i << ", " << GetCurrentUS() - start << "us"; - } - return predictor; -} - -TEST(NPUSubgraph, compare) { - // parsing input tensor shape, supported formats: "1,3,224,224" - // "1,3,224,224:1,80" - std::vector> input_tensor_shape = - ParseShape(FLAGS_input_tensor_shape); - // generate and run optimized CPU model - LOG(INFO) << " ================ CPU ================== "; - auto cpu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kARM), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/CPU"); - // generate and run optimized NPU model - LOG(INFO) << " ================ NPU ================== "; - auto npu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kARM), PRECISION(kFloat)}, - lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/NPU"); - // verify results - CompareOutputTensor(npu_predictor, cpu_predictor, FLAGS_output_tensor_num); -} - -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass.cc b/lite/core/mir/subgraph/generate_xpu_program_pass.cc deleted file mode 100644 index 319e1e51feb917b803753807ddbb1f72c2cb7084..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_xpu_program_pass.cc +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "lite/core/mir/subgraph/generate_xpu_program_pass.h" -#include -#include -#include -#include -#include -#include "lite/core/mir/graph_visualize_pass.h" -#include "lite/core/mir/pass_registry.h" -#include "lite/core/mir/pattern_matcher.h" - -#include "lite/backends/xpu/builder.h" -#include "lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h" -#include "lite/kernels/xpu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -std::shared_ptr GenerateXPUProgramPass::CvtVarNode( - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::mir::Node* var_node, - const Scope* scope) { - CHECK(var_node->IsArg()); - const auto& arg = var_node->AsArg(); - auto var_name = arg.name; - VLOG(4) << "[XPU] Convert var node " << var_name; - - auto* var = scope->FindVar(var_name); - CHECK(var); - auto* tensor = var->GetMutable(); - CHECK(tensor); - auto dims = tensor->dims(); - auto cvted_var_node = - std::make_shared(graph_ctx->builder->CreateTensor( - var_name, lite::xpu::CvtShape(dims), ::xtcl::Float(32))); - if (arg.is_weight) { - auto cvted_var_tensor = lite::xpu::CvtTensor(tensor); - graph_ctx->params->emplace(std::make_pair(var_name, *cvted_var_tensor)); - } - return cvted_var_node; -} - -void GenerateXPUProgramPass::CvtAllOpNodes( - const std::vector& op_nodes, - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes) { - const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance(); - const auto& supported_lists = bridges.AllFunctions(); - // return record all converted vars - // op node's inputs must be found in converted_vars - for (auto& node : op_nodes) { - lite::kernels::xpu::bridges::node_map_type input_nodes; - auto& stmt = node->AsStmt(); - for (auto& var_node : node->inlinks) { - auto& arg = var_node->AsArg(); - // weight should be handled in the converter, so skip here - if (arg.is_weight) { - continue; - } - auto var_name = arg.name; - if (!cvted_var_nodes->count(var_name)) { - cvted_var_nodes->insert(std::make_pair( - var_name, CvtVarNode(graph_ctx, var_node, stmt.op()->scope()))); - } - input_nodes.insert(*cvted_var_nodes->find(var_name)); - } - auto output_nodes = - supported_lists.at(stmt.op_type())(stmt.op(), graph_ctx, input_nodes); - cvted_var_nodes->insert(output_nodes.begin(), output_nodes.end()); - } -} - -std::string GenerateXPUProgramPass::BuildXPUGraph( - const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id) { - auto ordered_op_nodes = GetTopologicalOrder(op_nodes); - lite::kernels::xpu::bridges::graph_ctx_type graph_ctx; - graph_ctx.builder = std::make_shared(); - graph_ctx.params = - std::make_shared(); - lite::kernels::xpu::bridges::node_map_type cvted_var_nodes; - CvtAllOpNodes(ordered_op_nodes, &graph_ctx, &cvted_var_nodes); - - std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights"; - auto any_op = (*op_nodes.begin())->AsStmt().op(); - auto weight = any_op->scope()->Var(weight_var_name)->GetMutable(); - weight->set_persistable(true); - weight->set_precision(PRECISION(kInt8)); - // Compiling graph to XPU model and store mode data into weight tensor with - // persistable=true, Sothat the model parser can recognize it and save it to - // param files - std::vector> ordered_cvted_var_nodes; - for (auto out_data_var : out_data_vars) { - auto var_name = out_data_var->AsArg().name; - ordered_cvted_var_nodes.push_back(cvted_var_nodes[var_name]); 
- } - if (!lite::xpu::BuildModel(graph_ctx.builder, - graph_ctx.params, - &ordered_cvted_var_nodes, - weight)) { - LOG(WARNING) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")"; - throw std::runtime_error("[XPU] Build XPU graph failed."); - } - LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")"; - return weight_var_name; -} - -void GenerateXPUProgramPass::GenXPUSubgraph( - const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id) { - std::unordered_set in_data_vars; - std::unordered_set in_wgt_vars; - std::unordered_set out_data_vars; - std::unordered_set out_unused_vars; - FindInputOutputVars( - op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars); - - auto weight_var_name = - BuildXPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id); - - auto any_op = (*op_nodes.begin())->AsStmt().op(); - InsertNewNode(graph, - weight_var_name, - any_op->scope(), - any_op->valid_places(), - in_data_vars, - in_wgt_vars, - out_data_vars, - out_unused_vars); - - auto nodes2rm = GetNode2rm( - op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars}); - - GraphSafeRemoveNodes(graph.get(), nodes2rm); -} - -void GenerateXPUProgramPass::Apply(const std::unique_ptr& graph) { - LOG(INFO) << "[XPU] Before XPU Pass \n" << Visualize(graph.get()); - const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance(); - const auto& op_map = bridges.AllFunctions(); - std::vector supported_op_types; - for (auto& i : op_map) { - LOG(INFO) << "[XPU] Supported type: " << i.first; - supported_op_types.push_back(i.first); - } - - try { - int num_subgraph = FuseSubgraph(graph, supported_op_types); - InferOnce(graph); - auto op_nodes_all = ClassifySubgraph(graph); - CHECK_EQ(op_nodes_all.size(), num_subgraph); - int id = 1; - for (auto& op_nodes : op_nodes_all) { - LOG(INFO) << "[XPU] Converting Subgraph " << id; - GenXPUSubgraph(graph, op_nodes.second, id); - LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n" - << Visualize(graph.get()); - id++; - } - } catch (...) { - LOG(WARNING) << "[XPU] Build XPU graph failed."; - throw std::runtime_error("[XPU] Build XPU graph failed."); - } - - for (auto& item : graph->StmtTopologicalOrder()) { - if (item->IsStmt()) { - auto& stmt = item->AsStmt(); - LOG(INFO) << stmt; - insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); - } - } -} - -std::unique_ptr GenerateXPUProgramPass::GenProgram() { - LOG(INFO) << "[XPU] program insts.size=" << insts_.size(); - std::unique_ptr program( - new RuntimeProgram(std::move(insts_))); - return program; -} - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle - -REGISTER_MIR_PASS(generate_xpu_program_pass, - paddle::lite::mir::subgraph::GenerateXPUProgramPass) - .BindTargets({TARGET(kXPU)}); diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass.h b/lite/core/mir/subgraph/generate_xpu_program_pass.h deleted file mode 100644 index cf121ae9503201e8cf6be40fe9054ccaf6e4b172..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_xpu_program_pass.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "lite/backends/xpu/builder.h" -#include "lite/core/mir/pass.h" -#include "lite/core/mir/subgraph/subgraph_program_pass.h" -#include "lite/kernels/xpu/bridges/registry.h" - -namespace paddle { -namespace lite { -namespace mir { -namespace subgraph { - -class GenerateXPUProgramPass : public SubgraphProgramPass { - public: - using key2nodes_t = std::map; - - void Apply(const std::unique_ptr& graph) override; - std::unique_ptr GenProgram(); - - protected: - // nodes2cvt: op nodes to convert - // return cvted_vars: converted var nodes - void CvtAllOpNodes( - const std::vector& op_nodes, - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes); - - std::shared_ptr CvtVarNode( - lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx, - lite::mir::Node* var_node, - const Scope* scope); - - std::string BuildXPUGraph(const std::unordered_set& op_nodes, - const std::unordered_set& in_data_vars, - const std::unordered_set& out_data_vars, - int sub_id); - - void GenXPUSubgraph(const std::unique_ptr& graph, - const std::unordered_set& op_nodes, - int sub_id); - - private: - std::vector insts_; -}; - -} // namespace subgraph -} // namespace mir -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc b/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc deleted file mode 100644 index 728ecbc6b77666accd432b1ad82a03860588ab40..0000000000000000000000000000000000000000 --- a/lite/core/mir/subgraph/generate_xpu_program_pass_test.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include "lite/api/paddle_api.h" -#include "lite/api/paddle_use_kernels.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/api/paddle_use_passes.h" -#include "lite/api/test_helper.h" -#include "lite/utils/cp_logging.h" - -DEFINE_string(model_file, "", "model file path of combined protobuf model"); -DEFINE_string(params_file, "", "params file path of combined protobuf model"); -DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model"); -DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors"); -DEFINE_int32(output_tensor_num, 1, "number of output tensors"); - -namespace paddle { -namespace lite { - -std::vector> ParseShape(std::string txt) { - std::vector> shape; - while (!txt.empty()) { - size_t idx = txt.find_first_of(":"); - std::string dims = txt.substr(0, idx); - std::vector s; - while (!dims.empty()) { - size_t idx = dims.find_first_of(","); - int d = atoi(dims.substr(0, idx).c_str()); - VLOG(3) << d; - s.push_back(d); - if (idx == std::string::npos) { - break; - } else { - dims = dims.substr(idx + 1); - } - } - shape.push_back(s); - if (idx == std::string::npos) { - break; - } else { - txt = txt.substr(idx + 1); - } - } - return shape; -} - -int64_t ShapeProduction(std::vector shape) { - int64_t s = 1; - for (int64_t dim : shape) { - s *= dim; - } - return s; -} - -void FillInputTensor( - const std::shared_ptr& predictor, - const std::vector>& input_tensor_shape, - const float value) { - for (int i = 0; i < input_tensor_shape.size(); i++) { - auto input_tensor = predictor->GetInput(i); - input_tensor->Resize(input_tensor_shape[i]); - auto input_tensor_data = input_tensor->mutable_data(); - auto input_tensor_size = ShapeProduction(input_tensor->shape()); - for (int j = 0; j < input_tensor_size; j++) { - input_tensor_data[j] = value; - } - } -} - -void CompareOutputTensor( - const std::shared_ptr& tar_predictor, - const std::shared_ptr& ref_predictor, - const int output_tensor_num) { - for (int i = 0; i < output_tensor_num; i++) { - auto tar_output_tensor = tar_predictor->GetOutput(i); - auto ref_output_tensor = ref_predictor->GetOutput(i); - auto tar_output_tensor_data = tar_output_tensor->data(); - auto ref_output_tensor_data = ref_output_tensor->data(); - auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape()); - auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape()); - EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size); - for (size_t j = 0; j < ref_output_tensor_size; j++) { - auto diff = - std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) / - (std::fabs(ref_output_tensor_data[j]) + 1e-6); - VLOG(3) << diff; - EXPECT_LT(diff, 0.1); - } - } -} - -std::shared_ptr TestModel( - const std::string& model_dir, - const std::string& model_file, - const std::string& params_file, - const std::vector& valid_places, - const std::vector>& input_tensor_shape, - const std::string& optimized_model_dir) { - // generate optimized model - lite_api::CxxConfig cxx_config; - cxx_config.set_model_dir(model_dir); - cxx_config.set_model_file(model_file); - cxx_config.set_param_file(params_file); - cxx_config.set_valid_places(valid_places); - auto predictor = lite_api::CreatePaddlePredictor(cxx_config); - FillInputTensor(predictor, input_tensor_shape, -1); - predictor->SaveOptimizedModel(optimized_model_dir, - lite_api::LiteModelType::kNaiveBuffer); -#if 0 // TODO(hong19860320) supports light api for XPU - // load optimized model - lite_api::MobileConfig mobile_config; - 
mobile_config.set_model_dir(optimized_model_dir); - mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH); - mobile_config.set_threads(1); - predictor = lite_api::CreatePaddlePredictor(mobile_config); - FillInputTensor(predictor, input_tensor_shape, 1); -#endif - // run optimized model - for (int i = 0; i < FLAGS_warmup; i++) { - predictor->Run(); - } - for (int i = 0; i < FLAGS_repeats; i++) { - auto start = GetCurrentUS(); - predictor->Run(); - LOG(INFO) << i << ", " << GetCurrentUS() - start << "us"; - } - return predictor; -} - -TEST(XPUSubgraph, compare) { - // parsing input tensor shape, supported formats: "1,3,224,224" - // "1,3,224,224:1,80" - std::vector> input_tensor_shape = - ParseShape(FLAGS_input_tensor_shape); - // generate and run optimized CPU model - LOG(INFO) << " ================ CPU ================== "; - auto cpu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kX86), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/CPU"); - // generate and run optimized XPU model - LOG(INFO) << " ================ XPU ================== "; - auto xpu_predictor = - TestModel(FLAGS_model_dir, - FLAGS_model_file, - FLAGS_params_file, - {lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}, - lite_api::Place{TARGET(kX86), PRECISION(kFloat)}}, - input_tensor_shape, - FLAGS_optimized_model_dir + "/XPU"); - // verify results - CompareOutputTensor(xpu_predictor, cpu_predictor, FLAGS_output_tensor_num); -} - -} // namespace lite -} // namespace paddle diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d48b053a1a4140252d35e85d2351644d3c216e9 --- /dev/null +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -0,0 +1,551 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
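+//
+// SubgraphDetector clusters the operators accepted by a device-specific
+// "teller" into maximal subgraphs while keeping the whole graph acyclic;
+// SubgraphFuser then replaces each cluster with a single subgraph op.
+// Usage sketch (illustrative; assumes the teller is a
+// std::function<bool(Node*)> as used by the detector below):
+//   auto teller = [](Node* node) {
+//     return node->IsStmt() && node->AsStmt().op_type() == "conv2d";
+//   };
+//   std::vector<std::vector<Node*>> subgraphs =
+//       SubgraphDetector(graph.get(), teller)();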
+ +#include "lite/core/mir/subgraph/subgraph_detector.h" +#include +#include +#include +#include +#include +#include "lite/core/mir/dot.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher.h" +#include "lite/operators/subgraph_op.h" + +namespace paddle { +namespace lite { +namespace mir { + +using inference::analysis::Dot; + +std::string SubgraphVisualizer::operator()() { + inference::analysis::Dot dot; + const std::vector subgraph_colors{ + "red", "green", "cyan", "bisque3", + "coral", "darkseagreen1", "goldenrod1", "darkorchid", + "antiquewhite", "aquamarine", "azure", "bisque4", + "blue2", "brown1", "burlywood1", "cadetblue1", + "chartreuse1", "chocolate1", "coral1", "cornsilk", + "crimson", "cyan4", "darkgoldenrod4", "darkolivegreen2", + "darkorange2", "darkorchid2", "darkseagreen3", "darkslategray", + "deeppink2", "deepskyblue2", "dodgerblue", "firebrick", + "floralwhite", "gold1", "skyblue3", "indianred", + "indigo", "lavenderblush2", "lightblue1", "lightsalmon3", + "khaki1", "ivory4", "sandybrown", "olivedrab2", + "turquoise4", "snow3", "sienna4", "salmon2", + }; + std::unordered_map subgraph_indices; + for (int i = 0; i < subgraphs_.size(); i++) { + for (int j = 0; j < subgraphs_[i].size(); j++) { + subgraph_indices[subgraphs_[i][j]] = i; + } + } + std::unordered_map exists_ops; + std::set exists_args; + for (auto &node : graph_->StmtTopologicalOrder()) { + if (!node->IsStmt()) { + continue; + } + auto op_type = node->AsStmt().op_type(); + if (!exists_ops.count(op_type)) { + exists_ops[op_type] = 0; + } else { + exists_ops[op_type]++; + } + auto op_name = op_type + std::to_string(exists_ops[op_type]); + std::string op_color = "white"; + if (subgraph_indices.count(node)) { + auto subgraph_idx = subgraph_indices[node]; + op_name += "_subgraph_" + std::to_string(subgraph_idx); + op_color = subgraph_colors[subgraph_idx % subgraph_colors.size()]; + } + dot.AddNode(op_name, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", op_color)}); + for (auto &in_node : node->inlinks) { + auto arg_name = in_node->AsArg().name; + if (!exists_args.count(arg_name)) { + dot.AddNode(arg_name, {}); + exists_args.insert(arg_name); + } + dot.AddEdge(arg_name, op_name, {}); + } + for (auto &out_node : node->outlinks) { + auto arg_name = out_node->AsArg().name; + if (!exists_args.count(arg_name)) { + dot.AddNode(arg_name, {}); + exists_args.insert(arg_name); + } + dot.AddEdge(op_name, arg_name, {}); + } + } + + auto res = dot.Build(); + std::cout << "subgraphs: " << subgraphs_.size() << "\n" << res << std::endl; + return res; +} + +// Find the ancestor node +SubgraphDetector::node_dat_t * +SubgraphDetector::node_dat_t::UnionFindAncestor() { + node_dat_t *ancestor = this; + while (ancestor->union_find_parent != ancestor) { + ancestor = ancestor->union_find_parent; + } + return ancestor; +} + +// Merge the two adjacent nodes into one node. +// Suppose we have two adjacent nodes src and dst. +// We will perform the following operations: +// 1. add all inputs(except src) of dst to src inlinks. +// 2. add all outputs of dst to src outlinks. +// 3. change all the dst's inputs and outputs +// corresponding inlinks and outlinks to src node. +// 4. delete all dst's inlinks and outlinks. +void SubgraphDetector::node_dat_t::UnionFindCombine(node_dat_t *candidate) { + // Make this two node share the same ancestor. 
+  union_find_parent = UnionFindAncestor();
+  node_dat_t *candidate_ancestor = candidate->UnionFindAncestor();
+  candidate_ancestor->union_find_parent = union_find_parent;
+  candidate->union_find_parent = union_find_parent;
+
+  // Obtain the input and output nodes for the combined one
+  std::unordered_set<node_dat_t *> inputs(inlinks.begin(), inlinks.end());
+  std::unordered_set<node_dat_t *> outputs(candidate->outlinks.begin(),
+                                           candidate->outlinks.end());
+  for (auto *out_node : outlinks) {
+    if (out_node != candidate) {
+      outputs.insert(out_node);
+    }
+  }
+  for (auto *in_node : candidate->inlinks) {
+    if (in_node != this) {
+      inputs.insert(in_node);
+    }
+  }
+
+// Update the dst and src nodes' inlinks and outlinks.
+#ifdef __clang__
+  inlinks = node_set_t(inputs.begin(), inputs.end());
+  outlinks = node_set_t(outputs.begin(), outputs.end());
+  candidate->inlinks.clear();
+  candidate->outlinks.clear();
+#else
+  inlinks = std::move(node_set_t(inputs.begin(), inputs.end()));
+  outlinks = std::move(node_set_t(outputs.begin(), outputs.end()));
+  candidate->inlinks.clear();
+  candidate->outlinks.clear();
+#endif
+
+  // Change all of the dst node's inputs' and outputs' corresponding inlinks
+  // and outlinks to point to the src node.
+  for (auto *in_node : inlinks) {
+    for (auto *&out_node : in_node->outlinks) {
+      if (out_node == candidate) {
+        out_node = this;
+      }
+    }
+  }
+  for (auto *out_node : outlinks) {
+    for (auto *&in_node : out_node->inlinks) {
+      if (in_node == candidate) {
+        in_node = this;
+      }
+    }
+  }
+}
+
+// FlexibleDFS
+// If reverse is true, do a reverse DFS.
+// If the enter func is not nullptr, it calls enter(node) before visiting any
+// children of node.
+// If the leave func is not nullptr, it calls leave(node) after visiting all
+// parents of node.
+void SubgraphDetector::FlexibleDFS(
+    const node_set_t &source,
+    bool reverse,
+    const std::function<bool(const node_dat_t *)> &enter,
+    const std::function<bool(const node_dat_t *)> &leave) {
+  std::vector<std::pair<node_dat_t *, bool>> stack;  // node, leave
+  for (auto &node : source) {
+    stack.push_back(std::pair<node_dat_t *, bool>(node, false));
+  }
+  std::unordered_set<const node_dat_t *> visited;
+  while (!stack.empty()) {
+    auto top = stack.back();
+    stack.pop_back();
+
+    if (top.second) {
+      if (leave && !leave(top.first)) return;
+    }
+    if (visited.count(top.first)) continue;
+    visited.insert(top.first);
+
+    if (enter && !enter(top.first)) return;
+
+    if (leave)
+      stack.push_back(std::pair<node_dat_t *, bool>(top.first, true));
+    const node_set_t iter_nodes =
+        reverse == true ? top.first->inlinks : top.first->outlinks;
+    for (auto *node : iter_nodes) {
+      if (!visited.count(node)) {
+        stack.push_back(std::pair<node_dat_t *, bool>(node, false));
+      }
+    }
+  }
+}
+
+void SubgraphDetector::InitNodes(node_map_t *nodes) {
+  // Initialize and mark the subgraph detector nodes based on teller.
+  for (auto &it : *nodes) {
+    for (auto &in_node : it.first->inlinks) {
+      it.second->inlinks.push_back((*nodes)[in_node]);
+    }
+    for (auto &out_node : it.first->outlinks) {
+      it.second->outlinks.push_back((*nodes)[out_node]);
+    }
+    if (teller_(it.first)) {
+      it.second->marked = true;
+      if (it.first->IsStmt()) {
+        // If a function is inside the subgraph, mark all of its output
+        // variables as inside too, so that two marked functions end up in
+        // the same subgraph. For example, in A_function->var->B_function,
+        // if A_function is marked, var should also be marked, so that
+        // B_function joins the same subgraph as A_function when B_function
+        // is marked.
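+        // (Only Stmt nodes are collected into the final clusters below; the
+        // marked Arg nodes merely keep the union-find structure connected.)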
+void SubgraphDetector::InitNodes(node_map_t *nodes) {
+  // Initialize and mark the subgraph detector nodes based on teller.
+  for (auto &it : *nodes) {
+    for (auto &in_node : it.first->inlinks) {
+      it.second->inlinks.push_back((*nodes)[in_node]);
+    }
+    for (auto &out_node : it.first->outlinks) {
+      it.second->outlinks.push_back((*nodes)[out_node]);
+    }
+    if (teller_(it.first)) {
+      it.second->marked = true;
+      if (it.first->IsStmt()) {
+        // If a function is inside the subgraph, mark all of its output
+        // variables to be inside too, so that two marked functions end up in
+        // the same subgraph. Let's take an example: A_function->var->B_function.
+        // If A_function is marked, var should also be marked, so that
+        // B_function will be in the same subgraph as A_function if B_function
+        // is marked.
+        for (auto &out_node : it.first->outlinks) {
+          (*nodes)[out_node]->marked = true;
+        }
+      }
+    }
+  }
+}
+
+std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubgraphs(
+    node_map_t *nodes) {
+  for (auto &it : *nodes) {
+    node_dat_t *node = it.second;
+    if (!node->marked) {
+      continue;
+    }
+    // Our algorithm must guarantee that:
+    // 1. The graph is always a directed acyclic graph (DAG).
+    // 2. If there is a path in the subgraph from X to Y (X and Y are both
+    //    nodes in the subgraph), then all paths from X to Y are in the
+    //    subgraph.
+    //
+    // In order to achieve the above guarantee, for adjacent nodes src -> dst:
+    // 1. Get all of dst's input nodes except src.
+    // 2. Reverse DFS from those input nodes.
+    // 3. If there is a path from the input nodes to src, then src and dst
+    //    can not be fused into one node; otherwise it can be done.
+    while (true) {
+      std::unordered_set<node_dat_t *> contract_nodes;
+      for (auto *out_node : node->outlinks) {
+        // must be a candidate
+        if (!out_node->marked) continue;
+        // get all dst input nodes except the src node.
+        node_set_t source_nodes;
+        for (auto *in_node : out_node->inlinks) {
+          if (in_node != node) {
+            source_nodes.push_back(in_node);
+          }
+        }
+
+        // Reverse DFS from the source_nodes.
+        bool have_excess_path = false;
+        FlexibleDFS(source_nodes,
+                    true,
+                    nullptr,
+                    [&have_excess_path, node](const node_dat_t *n) {
+                      if (n == node) {
+                        have_excess_path = true;
+                        return false;
+                      }
+                      return true;
+                    });
+        if (have_excess_path) continue;
+        contract_nodes.insert(out_node);
+      }
+      if (contract_nodes.empty()) break;
+
+      for (auto &contract_node : contract_nodes) {
+        node->UnionFindCombine(contract_node);
+      }
+    }
+  }
+
+  std::unordered_map<node_dat_t *, std::vector<Node *>> clusters;
+  for (auto &node : graph_->StmtTopologicalOrder()) {
+    if (!node->IsStmt()) continue;
+    if ((*nodes)[node]->marked) {
+      clusters[(*nodes)[node]->UnionFindAncestor()].push_back(node);
+    }
+  }
+  std::vector<std::vector<Node *>> subgraphs;
+  std::for_each(clusters.begin(),
+                clusters.end(),
+                [&](const decltype(clusters)::value_type &it) {
+                  subgraphs.push_back(it.second);
+                });
+  return subgraphs;
+}
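Why the reverse DFS is necessary: contracting two marked nodes is only safe when no path between them escapes the subgraph. A worked counterexample, written as an illustrative comment sketch over a hypothetical three-op graph:

    // Consider a diamond where A and C are marked for fusion but B is not:
    //
    //     A ---> B ---> C
    //      \           ^
    //       +----------+
    //
    // Contracting A and C into one super node S would leave both edges
    // S -> B and B -> S, i.e. a cycle, so the graph would no longer be a
    // DAG and B could never be scheduled before or after S. The reverse
    // DFS from C's other input (here B) reaches A, sets have_excess_path,
    // and the contraction of this pair is skipped.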
+std::vector<std::vector<Node *>> SubgraphDetector::operator()() {
+  node_map_t nodes;
+  for (auto &node : graph_->mutable_nodes()) {
+    nodes[&node] = new node_dat_t(&node);
+    CHECK(nodes[&node]);
+  }
+  // Initialize and mark the subgraph detector nodes based on teller.
+  InitNodes(&nodes);
+  // Run the Extract algorithm to find all subgraphs.
+  std::vector<std::vector<Node *>> subgraphs = ExtractSubgraphs(&nodes);
+  for (auto &it : nodes) {
+    CHECK(it.second);
+    delete it.second;
+  }
+  return subgraphs;
+}
+
+void SubgraphFuser::InsertNewNode(SSAGraph *graph,
+                                  int subgraph_idx,
+                                  const std::vector<Node *> &subgraph_nodes) {
+  // Create and attach a new subgraph op
+  cpp::OpDesc subgraph_op_desc;
+  subgraph_op_desc.SetType("subgraph");
+
+  // Create a new sub block desc for storing all of the Ops and Vars of the
+  // target subgraph; sub_block_idx is set as an attribute of the subgraph op,
+  // and sub_block_idx < 0 means it's a new subgraph op
+  int sub_block_idx = -(subgraph_idx + 1);
+  auto sub_block_desc = new cpp::BlockDesc();
+  sub_block_desc->ClearOps();
+  sub_block_desc->ClearVars();
+  for (auto &op_node : subgraph_nodes) {
+    auto sub_block_op_desc = sub_block_desc->AddOp<cpp::OpDesc>();
+    *sub_block_op_desc = *op_node->AsStmt().op_info();
+    sub_block_op_desc->SetAttr(
+        kKernelTypeAttr,
+        op_node->AsStmt().picked_kernel().SerializedKernelType());
+  }
+  subgraph_op_desc.SetAttr("sub_block", sub_block_idx);
+
+  // Extract input and output nodes from the target subgraph
+  std::unordered_set<Node *> input_var_nodes;
+  std::unordered_set<Node *> weight_var_nodes;
+  std::unordered_set<Node *> output_var_nodes;
+  std::unordered_set<Node *> local_var_nodes;
+  std::unordered_set<Node *> unused_var_nodes;
+  ExtractInputsOutputs(subgraph_nodes,
+                       &input_var_nodes,
+                       &weight_var_nodes,
+                       &output_var_nodes,
+                       &local_var_nodes,
+                       &unused_var_nodes);
+
+  // Set the input and output name mapping which stores the real inputs and
+  // outputs
+  std::vector<std::string> input_var_names;
+  std::vector<std::string> output_var_names;
+  for (auto &var_node : input_var_nodes) {
+    input_var_names.push_back(var_node->AsArg().name);
+  }
+  for (auto &var_node : output_var_nodes) {
+    output_var_names.push_back(var_node->AsArg().name);
+  }
+  subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
+                                                     input_var_names);
+  subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
+                                                     output_var_names);
+
+  // Set all of the inputs and outputs to the target subgraph op,
+  // to prevent vars from being removed in RuntimeProgram::UpdateVarsOfProgram()
+  for (auto &var_node : weight_var_nodes) {
+    input_var_names.push_back(var_node->AsArg().name);
+  }
+  for (auto &var_node : local_var_nodes) {
+    output_var_names.push_back(var_node->AsArg().name);
+  }
+  for (auto &var_node : unused_var_nodes) {
+    output_var_names.push_back(var_node->AsArg().name);
+  }
+  subgraph_op_desc.SetInput("Inputs", input_var_names);
+  subgraph_op_desc.SetOutput("Outputs", output_var_names);
+  auto subgraph_op = LiteOpRegistry::Global().Create("subgraph");
+  static_cast<operators::SubgraphOp *>(subgraph_op.get())
+      ->SetSubBlock(sub_block_desc);
+  auto any_op = (*subgraph_nodes.begin())->AsStmt().op();
+  subgraph_op->Attach(subgraph_op_desc, any_op->scope());
+
+  // Create and add a new subgraph node into the graph
+  auto subgraph_op_node =
+      graph->GraphCreateInstructNode(subgraph_op, any_op->valid_places());
+  for (auto &var_node : input_var_nodes) {
+    IR_NODE_LINK_TO(var_node, subgraph_op_node);
+  }
+  for (auto &var_node : weight_var_nodes) {
+    IR_NODE_LINK_TO(var_node, subgraph_op_node);
+  }
+  for (auto &var_node : output_var_nodes) {
+    IR_OP_VAR_LINK(subgraph_op_node, var_node);
+  }
+  for (auto &var_node : local_var_nodes) {
+    IR_OP_VAR_LINK(subgraph_op_node, var_node);
+  }
+  for (auto &var_node : unused_var_nodes) {
+    IR_OP_VAR_LINK(subgraph_op_node, var_node);
+  }
+
+  // Create and assign the context to the picked kernel of the new subgraph
+  // node
+  auto &inst = subgraph_op_node->AsStmt();
+  inst.picked_kernel().SetContext(
+      ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
+
+  // Remove subgraph nodes and unused var nodes
+  auto nodes2rm = GetNodes2RM(subgraph_nodes,
+                              {input_var_nodes,
+                               weight_var_nodes,
+                               output_var_nodes,
+                               local_var_nodes,
+                               unused_var_nodes});
+  GraphSafeRemoveNodes(graph, nodes2rm);
+}
+
+void SubgraphFuser::ReplaceNodesWithSubgraphs(SSAGraph *graph,
+                                              const SubgraphTeller &teller,
+                                              int min_subgraph_size) {
+  std::vector<std::vector<Node *>> subgraphs =
+      SubgraphDetector(graph, teller)();
+  SubgraphVisualizer(graph, subgraphs)();
+  for (int subgraph_idx = 0; subgraph_idx < subgraphs.size(); subgraph_idx++) {
+    if (subgraphs[subgraph_idx].size() >= min_subgraph_size) {
+      InsertNewNode(graph, subgraph_idx, subgraphs[subgraph_idx]);
+    }
+  }
+}
+
+void SubgraphFuser::operator()() {
+  ReplaceNodesWithSubgraphs(graph_, teller_, min_subgraph_size_);
+}
+
+void ExtractInputsOutputs(const std::vector<Node *> &op_nodes,
+                          std::unordered_set<Node *> *input_var_nodes,
+                          std::unordered_set<Node *> *weight_var_nodes,
+                          std::unordered_set<Node *> *output_var_nodes,
+                          std::unordered_set<Node *> *local_var_nodes,
+                          std::unordered_set<Node *> *unused_var_nodes) {
+  for (auto &op_node : op_nodes) {
+    for (auto &var_node : op_node->inlinks) {
+      if (var_node->AsArg().is_weight) {
+        weight_var_nodes->insert(var_node);
+        continue;
+      }
+      if (!var_node->inlinks.empty()) {
+        // A var node can only come from one op node, so use front()
+        auto *prev_op_node = var_node->inlinks.front();
+        if (std::find(op_nodes.begin(), op_nodes.end(), prev_op_node) !=
+            op_nodes.end()) {
+          continue;
+        }
+      }
+      input_var_nodes->insert(var_node);
+    }
+    for (auto &var_node : op_node->outlinks) {
+      if (var_node->outlinks.empty()) {
+        // The next op is empty, so this var is actually unused
+        unused_var_nodes->insert(var_node);
+        continue;
+      }
+      // A var node can have more than one next op node, so continue if any
+      // of them is in op_nodes
+      bool next_op_in_nodes = false;
+      for (auto &next_op_node : var_node->outlinks) {
+        if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) !=
+            op_nodes.end()) {
+          next_op_in_nodes = true;
+        }
+      }
+      if (next_op_in_nodes) {
+        local_var_nodes->insert(var_node);
+        continue;
+      }
+      output_var_nodes->insert(var_node);
+    }
+  }
+}
+
+std::unordered_set<Node *> GetNodes2RM(
+    const std::vector<Node *> &op_nodes,
+    const std::vector<std::unordered_set<Node *>> &excluded_var_nodes) {
+  std::unordered_set<Node *> nodes2rm(op_nodes.begin(), op_nodes.end());
+  for (auto &op_node : op_nodes) {
+    for (auto &var_node : op_node->inlinks) {
+      if (!nodes2rm.count(var_node)) {
+        nodes2rm.insert(var_node);
+      }
+    }
+    for (auto &var_node : op_node->outlinks) {
+      if (!nodes2rm.count(var_node)) {
+        nodes2rm.insert(var_node);
+      }
+    }
+  }
+  // Excluded nodes should not be removed
+  for (auto &excluded_var_node : excluded_var_nodes) {
+    for (auto &var_node : excluded_var_node) {
+      if (nodes2rm.count(var_node)) {
+        nodes2rm.erase(var_node);
+      }
+    }
+  }
+  return nodes2rm;
+}
+
+static void SortHelper(Node *node,
+                       const std::unordered_set<Node *> &unordered_nodes,
+                       std::unordered_set<const Node *> *visited_nodes,
+                       std::vector<Node *> *ordered_nodes) {
+  for (auto &var_node : node->inlinks) {
+    if (var_node->inlinks.empty()) continue;
+    auto *op_node = var_node->inlinks.front();
+    if (unordered_nodes.count(op_node) && !visited_nodes->count(op_node)) {
+      SortHelper(op_node, unordered_nodes, visited_nodes, ordered_nodes);
+    }
+  }
+  ordered_nodes->push_back(node);
+  visited_nodes->insert(node);
+}
+
+std::vector<Node *> GetTopologicalOrder(
+    const std::unordered_set<Node *> &unordered_nodes) {
+  std::unordered_set<const Node *> visited_nodes;
+  std::vector<Node *> ordered_nodes;
+  for (auto &node : unordered_nodes) {
+    if (!node->IsStmt()) continue;
+    if (visited_nodes.count(node)) continue;
+    SortHelper(node, unordered_nodes, &visited_nodes, &ordered_nodes);
+  }
+  return ordered_nodes;
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
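The DOT text that SubgraphVisualizer prints to stdout can also be persisted and rendered offline, for example with Graphviz (`dot -Tpng subgraphs.dot -o subgraphs.png`). A minimal usage sketch, assuming the graph and subgraphs come from SubgraphDetector as in the passes further below; the helper name is hypothetical:

    #include <fstream>

    void DumpSubgraphDot(paddle::lite::mir::SSAGraph *graph,
                         const std::vector<std::vector<paddle::lite::mir::Node *>> &subgraphs,
                         const std::string &path) {
      // One box per op, filled with its subgraph's color; white for ops
      // outside every subgraph.
      std::string dot_text =
          paddle::lite::mir::SubgraphVisualizer(graph, subgraphs)();
      std::ofstream ofs(path);
      ofs << dot_text;
    }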
diff --git a/lite/core/mir/subgraph/subgraph_detector.h b/lite/core/mir/subgraph/subgraph_detector.h
new file mode 100644
index 0000000000000000000000000000000000000000..b6873655e976a785383269972221f001196431f8
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_detector.h
@@ -0,0 +1,127 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+using SubgraphTeller = std::function<bool(Node*)>;
+
+class SubgraphVisualizer {
+ public:
+  SubgraphVisualizer(SSAGraph* graph,
+                     const std::vector<std::vector<Node*>>& subgraphs)
+      : graph_(graph), subgraphs_(subgraphs) {}
+  std::string operator()();
+
+ protected:
+  SSAGraph* graph_{nullptr};
+  std::vector<std::vector<Node*>> subgraphs_;
+};
+
+/*
+ * Divide the graph into subgraphs according to the specified conditions.
+ * Return the divided clusters; a cluster consists of the op nodes in a
+ * subgraph.
+ */
+class SubgraphDetector {
+ public:
+  // This is a simple representation of a graph. Each node_dat_t holds a
+  // pointer to the underlying Node, so that graph analysis does not mutate
+  // the original graph.
+  struct node_dat_t;
+  using node_map_t = std::unordered_map<Node*, node_dat_t*>;
+  using node_set_t = std::vector<node_dat_t*>;
+  struct node_dat_t {
+    explicit node_dat_t(Node* _node) : node(_node) {}
+    Node* node;
+    bool marked{false};
+    node_dat_t* union_find_parent{this};
+    node_set_t inlinks{};
+    node_set_t outlinks{};
+    node_dat_t* UnionFindAncestor();
+    void UnionFindCombine(node_dat_t* candidate);
+  };
+  SubgraphDetector(SSAGraph* graph, const SubgraphTeller& teller)
+      : graph_(graph), teller_(teller) {}
+  std::vector<std::vector<Node*>> operator()();
+
+  void FlexibleDFS(const node_set_t& source,
+                   bool reverse,
+                   const std::function<bool(const node_dat_t*)>& enter,
+                   const std::function<bool(const node_dat_t*)>& leave);
+  void InitNodes(node_map_t* nodes);
+  std::vector<std::vector<Node*>> ExtractSubgraphs(node_map_t* nodes);
+
+ protected:
+  SSAGraph* graph_{nullptr};
+  SubgraphTeller teller_;
+};
+
+/*
+ * Replace all subgraphs with subgraph ops. A block desc is added to each
+ * subgraph op to wrap the original op nodes, and all var nodes of the
+ * original op nodes are kept as the inputs and outputs of the subgraph op.
+ */
+class SubgraphFuser {
+ public:
+  SubgraphFuser(SSAGraph* graph,
+                const SubgraphTeller& teller,
+                int min_subgraph_size)
+      : graph_(graph), teller_(teller), min_subgraph_size_{min_subgraph_size} {}
+  void operator()();
+
+  // Remove the op nodes of the subgraphs and replace them with subgraph ops.
+  void ReplaceNodesWithSubgraphs(SSAGraph* graph,
+                                 const SubgraphTeller& teller,
+                                 int min_subgraph_size);
+  // Create a subgraph node with a block desc to wrap the original op nodes of
+  // the subgraph
+  void InsertNewNode(SSAGraph* graph,
+                     int subgraph_idx,
+                     const std::vector<Node*>& subgraph_nodes);
+
+ protected:
+  SSAGraph* graph_{nullptr};
+  SubgraphTeller teller_;
+  int min_subgraph_size_;
+};
+
+void ExtractInputsOutputs(const std::vector<Node*>& op_nodes,
+                          std::unordered_set<Node*>* input_var_nodes,
+                          std::unordered_set<Node*>* weight_var_nodes,
+                          std::unordered_set<Node*>* output_var_nodes,
+                          std::unordered_set<Node*>* local_var_nodes,
+                          std::unordered_set<Node*>* unused_var_nodes);
+
+std::unordered_set<Node*> GetNodes2RM(
+    const std::vector<Node*>& op_nodes,
+    const std::vector<std::unordered_set<Node*>>& excluded_var_nodes);
+
+std::vector<Node*> GetTopologicalOrder(
+    const std::unordered_set<Node*>& unordered_nodes);
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_detector_test.cc
similarity index 65%
rename from lite/core/mir/subgraph/subgraph_program_pass_test.cc
rename to lite/core/mir/subgraph/subgraph_detector_test.cc
index 22e20b81d831ff25df090a7565e671b9139122f7..3b0d7c5cd5c8a0d0901750148359f430b6d49894 100644
--- a/lite/core/mir/subgraph/subgraph_program_pass_test.cc
+++ b/lite/core/mir/subgraph/subgraph_detector_test.cc
@@ -12,68 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include "lite/core/mir/subgraph/subgraph_detector.h"
 #include <gtest/gtest.h>
 #include <memory>
 #include <vector>
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
-#include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/ssa_graph.h"
 #include "lite/core/program.h"
 #include "lite/model_parser/cpp/program_desc.h"
 #include "lite/model_parser/model_parser.h"
 
 DEFINE_string(model_dir, "", "model_dir");
+DEFINE_string(model_file, "", "model file path of combined protobuf model");
+DEFINE_string(params_file, "", "params file path of combined protobuf model");
 
 namespace paddle {
 namespace lite {
 
-TEST(SubgraphTest, models) {
-  cpp::ProgramDesc program_desc;
-  auto scope = std::make_shared<Scope>();
-  // LoadModelPb(FLAGS_model_dir,
-  //             FLAGS_model_dir + "/model",
-  //             FLAGS_model_dir + "/params",
-  //             scope.get(),
-  //             &program_desc,
-  //             true);
-  LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
-  std::vector<Place> valid_places({
-      Place{TARGET(kHost), PRECISION(kFloat)},
-#ifdef LITE_WITH_ARM
-      Place{TARGET(kARM), PRECISION(kFloat)},
-#endif
-#ifdef LITE_WITH_NPU
-      Place{TARGET(kNPU), PRECISION(kFloat)},
-#endif
-#ifdef LITE_WITH_XPU
-      Place{TARGET(kXPU), PRECISION(kFloat)},
-#endif
-  });
-  lite::Program program(program_desc, scope, valid_places);
-  auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
-  graph->Build(program, valid_places);
-
-  std::vector<std::string> supported_op_types{"concat",
-                                              "conv2d",
-                                              "depthwise_conv2d",
-                                              "batch_norm",
-                                              "scale",
-                                              "pool2d",
-                                              "mul",
-                                              "elementwise_add",
-                                              "softmax",
-                                              "split",
-                                              "relu",
-                                              "reshape2",
-                                              "transpose2"};
-  auto* pass = new mir::subgraph::SubgraphProgramPass;
-  ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
-  LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
-}
-
-// return output_var_names
+// The helper functions for building a model manually
 std::vector<std::string> AddFCDesc(
     cpp::BlockDesc* block_desc,
     const std::shared_ptr<Scope>& scope,
@@ -84,24 +41,23 @@ std::vector<std::string> AddFCDesc(
   static int id = 0;
   std::string prefix = "fc_" + std::to_string(id);
   auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
-  auto* wgt = block_desc->AddVar<cpp::VarDesc>();
-  auto* bias = block_desc->AddVar<cpp::VarDesc>();
-  auto* out = block_desc->AddVar<cpp::VarDesc>();
+  auto* wgt = block_desc->AddVar<cpp::VarDesc>();
   wgt->SetName(prefix + "_W");
-  bias->SetName(prefix + "_Bias");
-  out->SetName(prefix + "_Out");
-  std::vector<std::string> out_var_names{prefix + "_Out"};
-
-  auto* wtensor = scope->Var(prefix + "_W")->GetMutable<Tensor>();
+  auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>();
   wtensor->Resize(wshape);
   wtensor->mutable_data<float>();
-  auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<Tensor>();
+  auto* bias = block_desc->AddVar<cpp::VarDesc>();
+  bias->SetName(prefix + "_Bias");
+  auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>();
   btensor->Resize({wshape[1]});
   btensor->mutable_data<float>();
-  scope->Var(prefix + "_Out")->GetMutable<Tensor>();
+  auto* out = block_desc->AddVar<cpp::VarDesc>();
+  out->SetName(prefix + "_Out");
+  std::vector<std::string> out_var_names{prefix + "_Out"};
+  scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
 
   op_desc->SetType("fc");
   op_desc->SetInput("Input", input_var_names);
@@ -127,7 +83,7 @@ std::vector<std::string> AddElementwiseAddDesc(
   out->SetName(prefix + "_Out");
   std::vector<std::string> out_var_names{prefix + "_Out"};
 
-  scope->Var(prefix + "_Out")->GetMutable<Tensor>();
+  scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
 
   op_desc->SetType("elementwise_add");
   op_desc->SetInput("X", input_X_names);
@@ -151,7 +107,7 @@ std::vector<std::string> AddFeedDesc(
   out->SetName(prefix + "_Out");
   std::vector<std::string> out_var_names{prefix + "_Out"};
 
-  scope->Var(prefix + "_Out")->GetMutable<Tensor>();
+  scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
 
   op_desc->SetType("feed");
   op_desc->SetInput("X", input_X_names);
@@ -174,7 +130,7 @@ std::vector<std::string> AddFetchDesc(
   out->SetName(prefix + "_Out");
   std::vector<std::string> out_var_names{prefix + "_Out"};
 
-  scope->Var(prefix + "_Out")->GetMutable<Tensor>();
+  scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
 
   op_desc->SetType("fetch");
   op_desc->SetInput("X", input_X_names);
@@ -184,41 +140,88 @@ std::vector<std::string> AddFetchDesc(
   return out_var_names;
 }
 
-std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
-    cpp::ProgramDesc* program_desc,
-    const std::shared_ptr<Scope>& scope,
-    const std::vector<Place>& valid_places) {
-  program_desc->ClearBlocks();
-  auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
+TEST(Subgraph, detect_simple_model) {
+  cpp::ProgramDesc program_desc;
+  std::vector<Place> valid_places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  // Build a simple network
+  program_desc.ClearBlocks();
+  auto* block_desc = program_desc.AddBlock<cpp::BlockDesc>();
   block_desc->ClearOps();
   block_desc->ClearVars();
-
   auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
   var_desc->SetName("feed_var");
-  auto* feed_var = scope->Var("feed_var")->GetMutable<Tensor>();
+  auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>();
   feed_var->Resize({1, 4});
   auto fc1_out = AddFCDesc(block_desc, scope, {"feed_var"}, {4, 5});
   auto fc2_out = AddFCDesc(block_desc, scope, fc1_out, {5, 2});
-
-  lite::Program program(*program_desc, scope, valid_places);
+  Program program(program_desc, scope, valid_places);
   auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
   graph->Build(program, valid_places);
-
-  return graph;
+  // Apply subgraph detector and check results
+  auto teller = [](mir::Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    auto op_type = stmt.op_type();
+    const std::vector<std::string> supported_types = {"fc"};
+    return std::find(supported_types.begin(), supported_types.end(), op_type) !=
+           supported_types.end();
+  };
+  std::vector<std::vector<mir::Node*>> subgraphs =
+      mir::SubgraphDetector(graph.get(), teller)();
+  ASSERT_EQ(subgraphs.size(), 1);
+  ASSERT_EQ(graph->nodes().size(), 9);
+  mir::SubgraphVisualizer(graph.get(), subgraphs)();
 }
 
-TEST(SubGraphTest, SimpleNet) {
+TEST(Subgraph, detect_custom_model) {
+  if (FLAGS_model_dir.empty() && FLAGS_model_file.empty() &&
+      FLAGS_params_file.empty()) {
+    LOG(INFO) << "Using --model_dir, or --model_file and --params_file to set "
+                 "the path of model files.";
+    return;
+  }
   cpp::ProgramDesc program_desc;
-  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
   auto scope = std::make_shared<Scope>();
-  auto graph = BuildSimpleNet(&program_desc, scope, places);
-
-  std::vector<std::string> supported_op_types{"fc"};
-  auto* pass = new mir::subgraph::SubgraphProgramPass;
-  ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
-
-  ASSERT_EQ(graph->nodes().size(), 9);
-  // LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
+  LoadModelPb(FLAGS_model_dir,
+              FLAGS_model_file,
+              FLAGS_params_file,
+              scope.get(),
+              &program_desc,
+              !FLAGS_model_file.empty() && !FLAGS_params_file.empty(),
+              false);
+  std::vector<Place> valid_places({
+#ifdef LITE_WITH_ARM
+      Place{TARGET(kARM), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_X86
+      Place{TARGET(kX86), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_NPU
+      Place{TARGET(kNPU), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_XPU
+      Place{TARGET(kXPU), PRECISION(kFloat)},
+#endif
+  });
+  Program program(program_desc, scope, valid_places);
+  auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
+  graph->Build(program, valid_places);
+  // Apply subgraph detector and check results
+  auto teller = [](mir::Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    auto op_type = stmt.op_type();
+    const std::vector<std::string> unsupported_types = {
+        "feed", "fetch", "subgraph"};
+    return std::find(unsupported_types.begin(),
+                     unsupported_types.end(),
+                     op_type) == unsupported_types.end();
+  };
+  std::vector<std::vector<mir::Node*>> subgraphs =
+      mir::SubgraphDetector(graph.get(), teller)();
+  ASSERT_EQ(subgraphs.size(), 1);
+  mir::SubgraphVisualizer(graph.get(), subgraphs)();
 }
 
 }  // namespace lite
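The two tests above show the two common teller styles: a whitelist for hand-built graphs and a blacklist for real models (everything except the framework ops). A hypothetical helper combining both styles, illustrative only and not part of the patch:

    #include <unordered_set>

    paddle::lite::mir::SubgraphTeller MakeTeller(
        const std::unordered_set<std::string>& whitelist,
        const std::unordered_set<std::string>& blacklist) {
      // Accept stmt nodes whose op type is whitelisted and not blacklisted.
      return [=](paddle::lite::mir::Node* node) {
        if (!node->IsStmt()) return false;
        const auto& op_type = node->AsStmt().op_type();
        return whitelist.count(op_type) != 0 && blacklist.count(op_type) == 0;
      };
    }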
diff --git a/lite/core/mir/subgraph/subgraph_pass.cc b/lite/core/mir/subgraph/subgraph_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5e2cecd277820ab39b5a25db6159591157982d01
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_pass.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/subgraph/subgraph_pass.h"
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/subgraph/subgraph_detector.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void NPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  std::unordered_set<std::string> supported_lists;
+#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
+#include "lite/kernels/npu/bridges/paddle_use_bridges.h"
+#undef USE_SUBGRAPH_BRIDGE
+  auto teller = [&](Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    return supported_lists.count(stmt.op_type()) != 0;
+  };
+  SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
+  fuser();
+}
+
+void XPUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  std::unordered_set<std::string> supported_lists;
+#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
+#include "lite/kernels/xpu/bridges/paddle_use_bridges.h"
+#undef USE_SUBGRAPH_BRIDGE
+  auto teller = [&](Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    return supported_lists.count(stmt.op_type()) != 0;
+  };
+  SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
+  fuser();
+}
+
+void BMSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  std::unordered_set<std::string> supported_lists;
+#define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
+#include "lite/kernels/bm/bridges/paddle_use_bridges.h"
+#undef USE_SUBGRAPH_BRIDGE
+  auto teller = [&](Node* node) {
+    if (!node->IsStmt()) return false;
+    auto& stmt = node->AsStmt();
+    return supported_lists.count(stmt.op_type()) != 0;
+  };
+  SubgraphFuser fuser(graph.get(), teller, 1 /* min_subgraph_size */);
+  fuser();
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(npu_subgraph_pass, paddle::lite::mir::NPUSubgraphPass)
+    .BindTargets({TARGET(kNPU)});
+REGISTER_MIR_PASS(xpu_subgraph_pass, paddle::lite::mir::XPUSubgraphPass)
+    .BindTargets({TARGET(kXPU)});
+REGISTER_MIR_PASS(bm_subgraph_pass, paddle::lite::mir::BMSubgraphPass)
+    .BindTargets({TARGET(kBM)});
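All three passes rely on the same X-macro trick: USE_SUBGRAPH_BRIDGE is re-defined to an insert statement just before including the target's bridge registry header. Assuming a hypothetical bridges header whose contents are USE_SUBGRAPH_BRIDGE(conv2d, kNPU) and USE_SUBGRAPH_BRIDGE(softmax, kNPU), the preprocessor output is equivalent to the sketch below, so registering a new bridge automatically widens the teller with no change to the pass:

    // What the #define/#include/#undef sequence expands to for the
    // assumed two-entry header:
    std::unordered_set<std::string> supported_lists;
    #define USE_SUBGRAPH_BRIDGE(op_type, target) supported_lists.insert(#op_type);
    USE_SUBGRAPH_BRIDGE(conv2d, kNPU)   // -> supported_lists.insert("conv2d");
    USE_SUBGRAPH_BRIDGE(softmax, kNPU)  // -> supported_lists.insert("softmax");
    #undef USE_SUBGRAPH_BRIDGE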
diff --git a/lite/core/mir/subgraph/subgraph_pass.h b/lite/core/mir/subgraph/subgraph_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ba0f2ab4aa52c384f4175de0eb34475b34fb94c
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_pass.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class NPUSubgraphPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+class XPUSubgraphPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+class BMSubgraphPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/subgraph_pass_test.cc b/lite/core/mir/subgraph/subgraph_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..247795a86ce2cbe962b161311f7845622ee3983e
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_pass_test.cc
@@ -0,0 +1,227 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
+
+DEFINE_string(model_file, "", "model file path of combined protobuf model");
+DEFINE_string(params_file, "", "params file path of combined protobuf model");
+DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
+DEFINE_string(input_tensor_shape, "1,3,224,224", "shape of input tensors");
+DEFINE_string(input_tensor_type, "float32", "data type of input tensors");
+DEFINE_string(output_tensor_type, "float32", "data type of output tensors");
+
+namespace paddle {
+namespace lite {
+
+// The helper functions for loading and running a model from the command line
+// and verifying the output data
+std::vector<std::string> TypeParsing(std::string text) {
+  std::vector<std::string> types;
+  while (!text.empty()) {
+    size_t index = text.find_first_of(":");
+    std::string type = text.substr(0, index);
+    VLOG(3) << type;
+    types.push_back(type);
+    if (index == std::string::npos) {
+      break;
+    } else {
+      text = text.substr(index + 1);
+    }
+  }
+  return types;
+}
+
+std::vector<std::vector<int64_t>> ShapeParsing(std::string text) {
+  std::vector<std::vector<int64_t>> shapes;
+  while (!text.empty()) {
+    size_t index = text.find_first_of(":");
+    std::string slice = text.substr(0, index);
+    std::vector<int64_t> shape;
+    while (!slice.empty()) {
+      size_t index = slice.find_first_of(",");
+      int d = atoi(slice.substr(0, index).c_str());
+      VLOG(3) << d;
+      shape.push_back(d);
+      if (index == std::string::npos) {
+        break;
+      } else {
+        slice = slice.substr(index + 1);
+      }
+    }
+    shapes.push_back(shape);
+    if (index == std::string::npos) {
+      break;
+    } else {
+      text = text.substr(index + 1);
+    }
+  }
+  return shapes;
+}
+
+int64_t ShapeProduction(std::vector<int64_t> shape) {
+  int64_t s = 1;
+  for (int64_t dim : shape) {
+    s *= dim;
+  }
+  return s;
+}
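A usage sketch of the two parsers above; the expected values follow directly from the code, and the wrapper function is hypothetical:

    // --input_tensor_shape="1,3,224,224:1,80" -> {{1,3,224,224}, {1,80}}
    // --input_tensor_type="float32:int64"     -> {"float32", "int64"}
    void ParsingExample() {
      auto shapes = ShapeParsing("1,3,224,224:1,80");
      CHECK_EQ(shapes.size(), 2u);
      CHECK_EQ(ShapeProduction(shapes[0]), 1 * 3 * 224 * 224);
      auto types = TypeParsing("float32:int64");
      CHECK_EQ(types.size(), 2u);
    }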
+
+void FillInputTensors(
+    const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
+    const std::vector<std::vector<int64_t>>& input_tensor_shape,
+    const std::vector<std::string>& input_tensor_type,
+    const float value) {
+#define FILL_TENSOR_WITH_TYPE(type)                            \
+  auto input_tensor_data = input_tensor->mutable_data<type>(); \
+  for (int j = 0; j < input_tensor_size; j++) {                \
+    input_tensor_data[j] = static_cast<type>(value);           \
+  }
+  for (int i = 0; i < input_tensor_shape.size(); i++) {
+    auto input_tensor = predictor->GetInput(i);
+    input_tensor->Resize(input_tensor_shape[i]);
+    auto input_tensor_size = ShapeProduction(input_tensor->shape());
+    if (input_tensor_type[i] == "float32") {
+      FILL_TENSOR_WITH_TYPE(float)
+    } else if (input_tensor_type[i] == "int64") {
+      FILL_TENSOR_WITH_TYPE(int64_t)
+    }
+  }
+#undef FILL_TENSOR_WITH_TYPE
+}
+
+void CheckOutputTensors(
+    const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
+    const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
+    const std::vector<std::string>& output_tensor_type) {
+#define CHECK_TENSOR_WITH_TYPE(type)                                          \
+  auto tar_output_tensor_data = tar_output_tensor->data<type>();              \
+  auto ref_output_tensor_data = ref_output_tensor->data<type>();              \
+  for (size_t j = 0; j < ref_output_tensor_size; j++) {                       \
+    auto abs_diff =                                                           \
+        std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]);     \
+    auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6); \
+    VLOG(5) << "val: " << tar_output_tensor_data[j]                           \
+            << " ref: " << ref_output_tensor_data[j]                          \
+            << " abs_diff: " << abs_diff << " rel_diff: " << rel_diff;        \
+    EXPECT_LT(rel_diff, 0.1);                                                 \
+  }
+  for (int i = 0; i < output_tensor_type.size(); i++) {
+    auto tar_output_tensor = tar_predictor->GetOutput(i);
+    auto ref_output_tensor = ref_predictor->GetOutput(i);
+    auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
+    auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
+    EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
+    if (output_tensor_type[i] == "float32") {
+      CHECK_TENSOR_WITH_TYPE(float)
+    } else if (output_tensor_type[i] == "int64") {
+      CHECK_TENSOR_WITH_TYPE(int64_t)
+    }
+  }
+#undef CHECK_TENSOR_WITH_TYPE
+}
+
+std::shared_ptr<lite_api::PaddlePredictor> TestModel(
+    const std::string& model_dir,
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::vector<lite_api::Place>& valid_places,
+    const std::vector<std::vector<int64_t>>& input_tensor_shape,
+    const std::vector<std::string>& input_tensor_type,
+    const std::string& optimized_model_dir) {
+  // Generate the optimized model
+  lite_api::CxxConfig cxx_config;
+  cxx_config.set_model_dir(model_dir);
+  cxx_config.set_model_file(model_file);
+  cxx_config.set_param_file(params_file);
+  cxx_config.set_valid_places(valid_places);
+  auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
+  predictor->SaveOptimizedModel(optimized_model_dir,
+                                lite_api::LiteModelType::kNaiveBuffer);
+  // Load the optimized model
+  lite_api::MobileConfig mobile_config;
+  mobile_config.set_model_from_file(optimized_model_dir + ".nb");
+  mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
+  mobile_config.set_threads(1);
+  predictor = lite_api::CreatePaddlePredictor(mobile_config);
+  FillInputTensors(predictor, input_tensor_shape, input_tensor_type, 1);
+  // Run the optimized model
+  for (int i = 0; i < FLAGS_warmup; i++) {
+    predictor->Run();
+  }
+  for (int i = 0; i < FLAGS_repeats; i++) {
+    auto start = GetCurrentUS();
+    predictor->Run();
+    LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
+  }
+  return predictor;
+}
+
+TEST(Subgraph, generate_model_and_check_precision) {
+  if (FLAGS_model_dir.empty() && FLAGS_model_file.empty() &&
+      FLAGS_params_file.empty()) {
+    LOG(INFO) << "Using --model_dir, or --model_file and --params_file to set "
+                 "the path of model files.";
+    return;
+  }
+  // Parsing the shape of input tensors from strings, supported formats:
+  // "1,3,224,224" and "1,3,224,224:1,80"
+  auto input_tensor_shape = ShapeParsing(FLAGS_input_tensor_shape);
+  // Parsing the data type of input and output tensors from strings, supported
+  // formats: "float32" and "float32:int64:int8"
+  auto input_tensor_type = TypeParsing(FLAGS_input_tensor_type);
+  auto output_tensor_type = TypeParsing(FLAGS_output_tensor_type);
+  std::vector<lite_api::Place> valid_places({
+#ifdef LITE_WITH_ARM
+      lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_X86
+      lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+#endif
+  });
+  // Generate and run the optimized model on CPU as the reference predictor
+  auto ref_predictor = TestModel(FLAGS_model_dir,
+                                 FLAGS_model_file,
+                                 FLAGS_params_file,
+                                 valid_places,
+                                 input_tensor_shape,
+                                 input_tensor_type,
+                                 FLAGS_optimized_model_dir + "_ref_opt_model");
+// Generate and run the optimized model on NPU/XPU as the target predictor
+#ifdef LITE_WITH_NPU
+  valid_places.push_back(lite_api::Place{TARGET(kNPU), PRECISION(kFloat)});
+#endif
+#ifdef LITE_WITH_XPU
+  valid_places.push_back(lite_api::Place{TARGET(kXPU), PRECISION(kFloat)});
+#endif
+  auto tar_predictor = TestModel(FLAGS_model_dir,
+                                 FLAGS_model_file,
+                                 FLAGS_params_file,
+                                 valid_places,
+                                 input_tensor_shape,
+                                 input_tensor_type,
+                                 FLAGS_optimized_model_dir + "_tar_opt_model");
+  // Check the difference of the output tensors between the reference
+  // predictor and the target predictor
+  CheckOutputTensors(tar_predictor, ref_predictor, output_tensor_type);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc
deleted file mode 100644
index 719a01dfd892f83da5e1d9b1efa6df758612acc7..0000000000000000000000000000000000000000
--- a/lite/core/mir/subgraph/subgraph_program_pass.cc
+++ /dev/null
@@ -1,345 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/core/mir/subgraph/subgraph_program_pass.h"
-#include <algorithm>
-#include <memory>
-#include <unordered_set>
-#include <vector>
-#include "lite/core/mir/graph_visualize_pass.h"
-#include "lite/core/mir/pass_registry.h"
-#include "lite/core/mir/pattern_matcher.h"
-
-namespace paddle {
-namespace lite {
-namespace mir {
-namespace subgraph {
-
-std::unordered_map<int, std::unordered_set<Node*>>
-SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr<SSAGraph>& graph) {
-  std::unordered_map<int, std::unordered_set<Node*>> op_nodes;
-  for (auto& item : graph->StmtTopologicalOrder()) {
-    if (!item->IsStmt()) continue;
-    auto& stmt = item->AsStmt();
-    int sub_id = stmt.subgraph_id();
-    if (sub_id < 1) continue;
-    if (!op_nodes.count(sub_id)) {
-      op_nodes[sub_id] = std::unordered_set<Node*>();
-    }
-    op_nodes.at(sub_id).insert(item);
-  }
-  return op_nodes;
-}
-
-cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc(
-    const std::string& weight_var_name,
-    const std::vector<std::string>& in_var_names,
-    const std::vector<std::string>& out_var_names) {
-  cpp::OpDesc op_desc;
-  op_desc.SetType("graph_op");
-  op_desc.SetInput("Inputs", in_var_names);
-  op_desc.SetInput("Weight", {weight_var_name});
-  op_desc.SetOutput("Outputs", out_var_names);
-  return op_desc;
-}
-
-void SubgraphProgramPass::InsertNewNode(
-    const std::unique_ptr<SSAGraph>& graph,
-    const std::string& weight_var_name,
-    Scope* scope,
-    const std::vector<Place>& valid_places,
-    std::unordered_set<Node*> in_data_vars,
-    std::unordered_set<Node*> in_wgt_vars,
-    std::unordered_set<Node*> out_data_vars,
-    std::unordered_set<Node*> out_unused_vars) {
-  std::vector<std::string> in_var_names;
-  std::vector<std::string> out_var_names;
-  for (auto i : in_data_vars) {
-    in_var_names.push_back(i->AsArg().name);
-  }
-  for (auto i : out_data_vars) {
-    out_var_names.push_back(i->AsArg().name);
-  }
-
-  auto op_desc = GenGraphOpDesc(weight_var_name, in_var_names, out_var_names);
-
-  auto graph_op = LiteOpRegistry::Global().Create("graph_op");
-  graph_op->Attach(op_desc, scope);
-  auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);
-
-  for (auto& in_var : in_data_vars) {
-    IR_NODE_LINK_TO(in_var, new_op_node);
-  }
-  for (auto& in_var : in_wgt_vars) {
-    IR_NODE_LINK_TO(in_var, new_op_node);
-  }
-  for (auto& out_var : out_data_vars) {
-    IR_OP_VAR_LINK(new_op_node, out_var);
-  }
-  for (auto& out_var : out_unused_vars) {
-    IR_OP_VAR_LINK(new_op_node, out_var);
-  }
-
-  // add weight node to store pre-compilied NPU model
-  auto new_weight_node = graph->NewArgumentNode(weight_var_name);
-  new_weight_node->AsArg().is_weight = true;
-  new_weight_node->AsArg().is_persist = true;
-  DirectedLink(new_weight_node, new_op_node);
-
-  // assign context
-  auto& inst = new_op_node->AsStmt();
-  inst.picked_kernel().SetContext(
-      ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
-}
-
-void SubgraphProgramPass::SortHelper(
-    Node* node,
-    const std::unordered_set<Node*>& nodes_all,
-    std::unordered_set<const Node*>* visited_nodes,
-    std::vector<Node*>* ret) {
-  for (auto& var_node : node->inlinks) {
-    if (var_node->inlinks.empty()) continue;
-    auto* op_node = var_node->inlinks.front();
-    if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) {
-      SortHelper(op_node, nodes_all, visited_nodes, ret);
-    }
-  }
-  ret->push_back(node);
-  visited_nodes->insert(node);
-}
-
-std::vector<Node*> SubgraphProgramPass::GetTopologicalOrder(
-    const std::unordered_set<Node*>& nodes) {
-  std::unordered_set<const Node*> visited;
-  std::vector<Node*> ret;
-  for (auto& node : nodes) {
-    if (!node->IsStmt()) continue;
-    if (visited.count(node)) continue;
-    SortHelper(node, nodes, &visited, &ret);
-  }
-  return ret;
-}
-
-void SubgraphProgramPass::FindInputOutputVars(
-    const std::unordered_set<Node*>& op_nodes,
-    std::unordered_set<Node*>* in_data_vars,
-    std::unordered_set<Node*>* in_wgt_vars,
-    std::unordered_set<Node*>* out_data_vars,
-    std::unordered_set<Node*>* out_unused_vars) {
-  for (auto& op_node : op_nodes) {
-    for (auto& in_var : op_node->inlinks) {
-      if (in_var->AsArg().is_weight) {
-        in_wgt_vars->insert(in_var);
-        continue;
-      }
-      if (!in_var->inlinks.empty()) {
-        // var can only come from one op node, so use front
-        auto* pre_op_node = in_var->inlinks.front();
-        if (op_nodes.count(pre_op_node)) {
-          continue;
-        }
-      }
-      in_data_vars->insert(in_var);
-    }
-    for (auto& out_var : op_node->outlinks) {
-      if (out_var->outlinks.empty()) {
-        // the next op is empty so this var is actually unused
-        out_unused_vars->insert(out_var);
-        continue;
-      }
-      // var can have more than one next op node
-      // so, if any one in the op_nodes then continue
-      bool next_op_in_nodes = false;
-      for (auto& next_op_node : out_var->outlinks) {
-        if (op_nodes.count(next_op_node)) {
-          next_op_in_nodes = true;
-        }
-      }
-      if (next_op_in_nodes) {
-        continue;
-      }
-
-      out_data_vars->insert(out_var);
-    }
-  }
-}
-
-std::unordered_set<Node*> SubgraphProgramPass::GetNode2rm(
-    const std::unordered_set<Node*>& op_nodes,
-    const std::vector<std::unordered_set<Node*>>& excluded_nodes) {
-  std::unordered_set<Node*> nodes2rm(op_nodes.begin(), op_nodes.end());
-  for (auto& op_node : op_nodes) {
-    for (auto& in_var : op_node->inlinks) {
-      if (!nodes2rm.count(in_var)) {
-        nodes2rm.insert(in_var);
-      }
-    }
-    for (auto& out_var : op_node->outlinks) {
-      if (!nodes2rm.count(out_var)) {
-        nodes2rm.insert(out_var);
-      }
-    }
-  }
-  // some nodes should not be removed
-  for (auto& e : excluded_nodes) {
-    for (auto& i : e) {
-      if (nodes2rm.count(i)) {
-        nodes2rm.erase(i);
-      }
-    }
-  }
-  return nodes2rm;
-}
-
-void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
-  for (auto& item : graph->StmtTopologicalOrder()) {
-    if (!item->IsStmt()) continue;
-    auto& stmt = item->AsStmt();
-    auto& op = stmt.op();
-    auto scope = op->scope();
-    std::string op_type = op->op_info()->Type();
-    // check the dimension of input variables in the scope, must not be empty!
-    if (op_type == "feed") {
-      auto input_var_names = op->op_info()->output_names();
-      CHECK_GE(input_var_names.size(), 1);
-      for (auto input_var_name : input_var_names) {
-        auto input_var = scope->FindVar(input_var_name);
-        CHECK(input_var) << "No input variable '" << input_var_name
-                         << "' found in scope " << scope;
-        auto input = input_var->GetMutable<lite::Tensor>();
-        CHECK(!input->dims().empty()) << "The dimension of input variable '"
-                                      << input_var_name
-                                      << "' can not be empty.";
-      }
-      continue;
-    }
-    if (op_type == "fetch") {
-      continue;
-    }
-    op->CheckShape();
-    op->InferShape();
-
-#ifndef LITH_WITH_XPU
-    // TOOD(xxx): remove Launch() at last
-    auto& kkks = stmt.kernels();
-    if (!kkks.empty()) {
-      auto& kk = stmt.kernels().front();
-      if (kk) {
-        kk->Launch();
-      }
-    }
-#endif
-  }
-}
-
-void SubgraphProgramPass::InitSubgraphID(
-    const std::unique_ptr<SSAGraph>& graph,
-    const std::vector<std::string>& supported_op_types) {
-  for (auto& item : graph->StmtTopologicalOrder()) {
-    if (!item->IsStmt()) continue;
-    auto& stmt = item->AsStmt();
-    stmt.ClearSubgraphID();
-    if (std::find(supported_op_types.begin(),
-                  supported_op_types.end(),
-                  stmt.op_type()) != supported_op_types.end()) {
-      stmt.SetSubgraphID(0);
-      LOG(INFO) << "supported " << stmt.op_type();
-    } else {
-      LOG(INFO) << "======= not supported " << stmt.op_type();
-    }
-  }
-}
-
-// mark current and all output supported nodes
-void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
-                                                  int to_id,
-                                                  int from_id) {
-  if (!node) return;
-  if (node->IsStmt()) {
-    auto& stmt = node->AsStmt();
-    if (stmt.subgraph_id() == from_id) {
-      stmt.SetSubgraphID(to_id);
-      for (auto& i : node->outlinks) {
-        ChangeAllOutConnectedID(i, to_id, from_id);
-      }
-    } else {
-      LOG(INFO) << "failed op type:" << stmt.op_type();
-      return;
-    }
-  } else {
-    // this it arg node
-    bool all_out_op_supported = true;
-    for (auto& i : node->outlinks) {
-      if (!i->IsStmt()) return;
-      auto& stmt = i->AsStmt();
-      if (stmt.subgraph_id() < from_id) {
-        all_out_op_supported = false;
-      }
-    }
-    if (!all_out_op_supported) {
-      return;
-    }
-    for (auto& i : node->outlinks) {
-      CHECK(i->IsStmt());
-      auto& stmt = i->AsStmt();
-      if (stmt.subgraph_id() == from_id) {
-        stmt.SetSubgraphID(to_id);
-        for (auto& o : i->outlinks) {
-          ChangeAllOutConnectedID(o, to_id, from_id);
-        }
-      }
-    }
-  }
-}
-
-int SubgraphProgramPass::FuseSubgraphID(
-    const std::unique_ptr<SSAGraph>& graph) {
-  int sub_id = 1;  // id start from 1 not 0
-  for (auto& item : graph->StmtTopologicalOrder()) {
-    // bool inputvar = false;
-    if (!item->IsStmt()) continue;
-    auto& stmt = item->AsStmt();
-    /*
-    if (stmt.subgraph_id() == -1) {
-      for (auto& i : item->outlinks) {
-        for (auto& j : i->outlinks) {
-          if (j->IsStmt()) {
-            auto& jstmt = j->AsStmt();
-            if (jstmt.subgraph_id() == 0) inputvar = true;
-          }
-        }
-      }
-    }
-    */
-    if (stmt.subgraph_id() != 0) continue;
-    ChangeAllOutConnectedID(item, sub_id);
-    sub_id++;
-  }
-  return sub_id - 1;
-}
-
-int SubgraphProgramPass::FuseSubgraph(
-    const std::unique_ptr<SSAGraph>& graph,
-    const std::vector<std::string>& supported_op_types) {
-  InitSubgraphID(graph, supported_op_types);
-  return FuseSubgraphID(graph);
-}
-}  // namespace subgraph
-}  // namespace mir
-}  // namespace lite
-}  // namespace paddle
-
-REGISTER_MIR_PASS(subgraph_program_pass,
-                  paddle::lite::mir::subgraph::SubgraphProgramPass)
-    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/subgraph/subgraph_program_pass.h b/lite/core/mir/subgraph/subgraph_program_pass.h
deleted file mode 100644
index 24c0233bbb428a71fa5645b23573494b5067d8b1..0000000000000000000000000000000000000000
--- a/lite/core/mir/subgraph/subgraph_program_pass.h
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <map>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-#include "lite/core/mir/pass.h"
-
-namespace paddle {
-namespace lite {
-namespace mir {
-namespace subgraph {
-
-class SubgraphProgramPass : public ProgramPass {
- public:
-  using key2nodes_t = std::map<std::string, Node*>;
-
-  // make all the linked ops in subgraph with same subgraph_id
-  // return the fused subgraph numbers
-  int FuseSubgraph(const std::unique_ptr<SSAGraph>& graph,
-                   const std::vector<std::string>& supported_op_types);
-
-  void Apply(const std::unique_ptr<SSAGraph>& graph) override{};
-
- protected:
-  void InferOnce(const std::unique_ptr<SSAGraph>& graph);
-
-  // clear all subgraph id and mark all ops, which could be fuse, as id zero
-  void InitSubgraphID(const std::unique_ptr<SSAGraph>& graph,
-                      const std::vector<std::string>& supported_op_types);
-
-  // make all the linked ops in subgraph with same subgraph_id
-  // return the fused subgraph numbers
-  int FuseSubgraphID(const std::unique_ptr<SSAGraph>& graph);
-
-  // // GenerateFusedGraph:
-  // std::unique_ptr<SSAGraph> GenerateFusedGraph(const
-  // std::unique_ptr<SSAGraph>& graph, int sub_num);
-  void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0);
-
-  // Below function cloud be useful in child classes //
-  // classify node by subgraph id
-  std::unordered_map<int, std::unordered_set<Node*>> ClassifySubgraph(
-      const std::unique_ptr<SSAGraph>& graph);
-
-  // generate the graph op desc
-  cpp::OpDesc GenGraphOpDesc(const std::string& weight_var_name,
-                             const std::vector<std::string>& in_var_names,
-                             const std::vector<std::string>& out_var_names);
-
-  // insert a new graph op node
-  void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
-                     const std::string& weight_var_name,
-                     Scope* scope,
-                     const std::vector<Place>& valid_places,
-                     std::unordered_set<Node*> in_data_vars,
-                     std::unordered_set<Node*> in_wgt_vars,
-                     std::unordered_set<Node*> out_data_vars,
-                     std::unordered_set<Node*> out_unused_vars);
-
-  // Sort and return the topology order of nodes set
-  std::vector<Node*> GetTopologicalOrder(
-      const std::unordered_set<Node*>& nodes);
-
-  // find all input data vars, input weight vars,
-  // output data vars and output vars from the nodes
-  void FindInputOutputVars(const std::unordered_set<Node*>& op_nodes,
-                           std::unordered_set<Node*>* in_data_vars,
-                           std::unordered_set<Node*>* in_wgt_vars,
-                           std::unordered_set<Node*>* out_data_vars,
-                           std::unordered_set<Node*>* out_unused_vars);
-
-  // return the node to remove in the subgraph
-  std::unordered_set<Node*> GetNode2rm(
-      const std::unordered_set<Node*>& op_nodes,
-      const std::vector<std::unordered_set<Node*>>& excluded_nodes);
-
- private:
-  // sort nodes to operational sequence
-  void SortHelper(Node* node,
-                  const std::unordered_set<Node*>& nodes_all,
-                  std::unordered_set<const Node*>* visited_nodes,
-                  std::vector<Node*>* ret);
-};
-
-}  // namespace subgraph
-}  // namespace mir
-}  // namespace lite
-}  // namespace paddle
diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc
index b008faa687474a88988adb9da81c594306298b26..ae74bd8d4d5647139a13509dfda0bb2b41ecc5c7 100644
--- a/lite/core/mir/type_target_cast_pass.cc
+++ b/lite/core/mir/type_target_cast_pass.cc
@@ -16,6 +16,7 @@
 #include <list>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
@@ -35,18 +36,23 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 
   CHECK(!valid_places_.empty());
 
+  // record the copied node.
+  std::unordered_map<std::string, Node*> copied_nodes;
+
   for (auto& node : nodes) {
     if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
     auto inlinks = node->inlinks;
     for (auto* in : inlinks) {
-      ComplementInputs(graph.get(), node, in);
+      ComplementInputs(graph.get(), node, in, &copied_nodes);
     }
   }
 }
 
-void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
-                                               Node* inst_node,
-                                               Node* in) {
+void TypeTargetTransformPass::ComplementInputs(
+    SSAGraph* graph,
+    Node* inst_node,
+    Node* in,
+    std::unordered_map<std::string, Node*>* copied_nodes) {
   // If this input is out of date.
   if (inst_node->inlinks.end() ==
       std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
@@ -67,8 +73,13 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
             << " for kernel " << inst.op()->DebugString() << " "
             << *in->AsArg().type << " -> " << *decl_arg_type;
     // Add an IoCopy instruction to make the input compatible with other dist.
-    AddIoCopyInst(
-        *in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_);
+    AddIoCopyInst(*in->AsArg().type,
+                  *decl_arg_type,
+                  in,
+                  graph,
+                  inst_node,
+                  copied_nodes,
+                  valid_places_);
   }
 }
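The new copied_nodes map memoizes the output node of each inserted io_copy op, keyed by the source variable's name, so a variable consumed by several kernels on a mismatched target gets one copy instead of one per consumer. The pattern, reduced to a self-contained sketch with illustrative names (not the pass's actual API):

    #include <functional>
    #include <string>
    #include <unordered_map>

    template <typename Node>
    Node* GetOrCreateCopy(const std::string& var_name,
                          std::unordered_map<std::string, Node*>* copied_nodes,
                          const std::function<Node*()>& create_io_copy) {
      auto it = copied_nodes->find(var_name);
      if (it != copied_nodes->end()) return it->second;  // reuse existing copy
      Node* copied = create_io_copy();  // build io_copy op plus its output var
      (*copied_nodes)[var_name] = copied;
      return copied;
    }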
@@ -78,128 +89,132 @@ void TypeTargetTransformPass::AddIoCopyInst(
     Node* in,
     SSAGraph* graph,
     Node* inst_node,
+    std::unordered_map<std::string, Node*>* copied_nodes,
     const std::vector<Place>& valid_places) {
   CHECK(!valid_places.empty()) << "valid_place should be set";
   // var -> new_transform_op -> new_var -> inst
   // So there will be a new Argument node and a new IoCopy Statement Node.
   CHECK(in->IsArg());
+  // auto node_id = [&] { return graph->nodes().size(); };
   auto io_copy_output_name =
       string_format("%s/target_trans", in->AsArg().name.c_str());
   //  string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id());
-  // TODO(MyPandaShaoxiang) should set same place with input?
-  auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
-  // Set the place for io_copy_output_arg node, the target should be equal to
-  // to.target()
-  // The precision and layout should be equal to from.precision(), from.layout()
-  io_copy_output_arg->AsArg().type =
-      LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
-  auto* io_copy_inst = graph->NewInstructNode();
-
-  bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
-  std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
-  io_copy_output_arg->AsArg().is_persist = in_persist;
-  // create Op and kernels.
-  auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
-  CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
-  // CHECK(io_copy_op);
-  // Create the new var manually.
-  inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
-
-  // Create IoCopy Instruction.
-  cpp::OpDesc op_desc;
-  op_desc.SetType(io_copy_type);
-  op_desc.SetInput("Input", {in->AsArg().name});
-  op_desc.SetOutput("Out", {io_copy_output_name});
-
-  io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
-  auto kernels = io_copy_op->CreateKernels(valid_places);
-  // fix(MyPandaShaoxiang): select kernel that input_dcl_type same as in.type
-  bool is_found = false;
-  std::vector<std::unique_ptr<KernelBase>> selected_kernels;
-  for (auto& kernel : kernels) {
-    const Type* in_arg_ty = kernel->GetInputDeclType("Input");
-    const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
-
-    VLOG(4) << "------ kernel info -------";
-    VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
-    VLOG(4) << "from(last kernel output):" << from;
-    VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
-    VLOG(4) << "to:" << to << "\n";
-
-    // kernel choose branch for opencl backend
-    // judge inst's target whether is kOpenCL
-    // Note: to == *decl_arg_type == in of inst, not output of last inst
-    // ignore [layout check] for layout between [to] and [from]
-    //   Because all of origin opencl insts in model, are not default layout
-    //   NCHW,
-    //   so skip layout check.
-    // detailed node info see below:
-    //     [*in->AsArg().type] -> [from]: out of inst's previous kernel
-    //     [*decl_arg_type] -> [to]: input of inst, not output of last
-    //     [in_arg_ty]: in of io_copy
-    //     [out_arg_ty]: out of io_copy
-    //
-    // noto: replace LITE_WITH_OPENCL macro with judge input and output target
-    // of io_copy
-    if ((in_arg_ty->target() == TARGET(kOpenCL) ||
-         out_arg_ty->target() == TARGET(kOpenCL)) &&  // judge OpenCL first
-        (TargetCompatibleTo(*in_arg_ty, from) &&
-         PrecisionCompatibleTo(*in_arg_ty, from) &&
-         DeviceCompatibleTo(*in_arg_ty, from) &&
-         TargetCompatibleTo(*out_arg_ty, to))) {
-      VLOG(4) << "picked, opencl found";
-      is_found = true;
-    } else if (TypeCompatible(*in_arg_ty, from) &&
-               out_arg_ty->target() == to.target()) {
-      VLOG(4) << "picked";
-      is_found = true;
-    }
-    if (is_found) {
-      selected_kernels.emplace_back(std::move(kernel));
-      // we pick the kernel
-      io_copy_inst->AsStmt(
-          io_copy_type, std::move(selected_kernels), io_copy_op);
-      break;
+  if (copied_nodes->count(in->AsArg().name)) {
+    // Remove the old link
+    RemoveDirectedLink(in, inst_node);
+
+    // Update the original instruction OpDesc.
+    // Update its input to the io_copy_output_name
+    // Add new link, newarg->inst
+    DirectedLink(copied_nodes->at(in->AsArg().name),
+                 inst_node);  // [io_copy kernel]'s output -> [current kernel]
+
+    UpdateInstNode(in, graph, inst_node, io_copy_output_name);
+  } else {
+    // TODO(MyPandaShaoxiang) should set same place with input?
+    auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
+    // Set the place for io_copy_output_arg node, the target should be equal to
+    // to.target()
+    // The precision and layout should be equal to from.precision(),
+    // from.layout()
+    io_copy_output_arg->AsArg().type =
+        LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
+    auto* io_copy_inst = graph->NewInstructNode();
+
+    bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
+    std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
+    io_copy_output_arg->AsArg().is_persist = in_persist;
+    // create Op and kernels.
+    auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
+    CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
+    // CHECK(io_copy_op);
+    // Create the new var manually.
+    inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
+
+    // Create IoCopy Instruction.
+    cpp::OpDesc op_desc;
+    op_desc.SetType(io_copy_type);
+    op_desc.SetInput("Input", {in->AsArg().name});
+    op_desc.SetOutput("Out", {io_copy_output_name});
+
+    io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
+    auto kernels = io_copy_op->CreateKernels(valid_places);
+    // fix(MyPandaShaoxiang): select kernel that input_dcl_type same as in.type
+    bool is_found = false;
+    std::vector<std::unique_ptr<KernelBase>> selected_kernels;
+    for (auto& kernel : kernels) {
+      const Type* in_arg_ty = kernel->GetInputDeclType("Input");
+      const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
+
+      VLOG(4) << "------ kernel info -------";
+      VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
+      VLOG(4) << "from(last kernel output):" << from;
+      VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
+      VLOG(4) << "to:" << to << "\n";
+
+      // kernel choose branch for opencl backend
+      // judge inst's target whether is kOpenCL
+      // Note: to == *decl_arg_type == in of inst, not output of last inst
+      // ignore [layout check] for layout between [to] and [from]
+      //   Because all of origin opencl insts in model, are not default layout
+      //   NCHW,
+      //   so skip layout check.
+      // detailed node info see below:
+      //     [*in->AsArg().type] -> [from]: out of inst's previous kernel
+      //     [*decl_arg_type] -> [to]: input of inst, not output of last
+      //     [in_arg_ty]: in of io_copy
+      //     [out_arg_ty]: out of io_copy
+      //
+      // noto: replace LITE_WITH_OPENCL macro with judge input and output target
+      // of io_copy
+      if ((in_arg_ty->target() == TARGET(kOpenCL) ||
+           out_arg_ty->target() == TARGET(kOpenCL)) &&  // judge OpenCL first
+          (TargetCompatibleTo(*in_arg_ty, from) &&
+           PrecisionCompatibleTo(*in_arg_ty, from) &&
+           DeviceCompatibleTo(*in_arg_ty, from) &&
+           TargetCompatibleTo(*out_arg_ty, to))) {
+        VLOG(4) << "picked, opencl found";
+        is_found = true;
+      } else if (TypeCompatible(*in_arg_ty, from) &&
+                 out_arg_ty->target() == to.target()) {
+        VLOG(4) << "picked";
+        is_found = true;
+      }
+
+      if (is_found) {
+        selected_kernels.emplace_back(std::move(kernel));
+        // we pick the kernel
+        io_copy_inst->AsStmt(
+            io_copy_type, std::move(selected_kernels), io_copy_op);
+        (*copied_nodes)[in->AsArg().name] = io_copy_output_arg;
+        break;
+      }
+
+      VLOG(4) << "not picked";
     }
 
-    VLOG(4) << "not picked";
-  }
+    CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from
+                    << ":" << in->AsArg().name << " -> " << to << ":"
+                    << inst_node->AsStmt().op_info()->Type();
+    // Remove the old link
+    RemoveDirectedLink(in, inst_node);
 
-  CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from
-                  << ":" << in->AsArg().name << " -> " << to << ":"
-                  << inst_node->AsStmt().op_info()->Type();
-  // Remove the old link
-  RemoveDirectedLink(in, inst_node);
-
-  // Update the original instruction OpDesc.
-  // Update its input to the io_copy_output_name
-  // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
-  DirectedLink(in, io_copy_inst);  // [last kernel]'s output -> [io_copy kernel]
-  DirectedLink(
-      io_copy_inst,
-      io_copy_output_arg);  // [io_copy kernel] -> [io_copy kernel]'s output
-  DirectedLink(io_copy_output_arg,
-               inst_node);  // [io_copy kernel]'s output -> [current kernel]
+    // Update the original instruction OpDesc.
+ // Update its input to the io_copy_output_name + // Add new link, var -> new_inst, new_inst->newarg, newarg->inst + DirectedLink(in, + io_copy_inst); // [last kernel]'s output -> [io_copy kernel] + DirectedLink( + io_copy_inst, + io_copy_output_arg); // [io_copy kernel] -> [io_copy kernel]'s output + DirectedLink(io_copy_output_arg, + inst_node); // [io_copy kernel]'s output -> [current kernel] - // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), - in->AsArg().name, - io_copy_output_name); - auto original_selected_kernel = - std::move(inst_node->AsStmt().kernels().front()); - auto update_op_info = *inst_node->AsStmt().op_info(); - // ResetOp() will change the Stmt op_info_ value, - // after that the old op_info_ value will be nullified. - // So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp. - // `update_op_info` is the copy of `*inst_node->AsStmt().op_info(). - // Whenever update the op_info of a stmt, we should call its ResetOp(). - inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places()); - inst_node->AsStmt().kernels().clear(); - inst_node->AsStmt().kernels().emplace_back( - std::move(original_selected_kernel)); + UpdateInstNode(in, graph, inst_node, io_copy_output_name); + } std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { @@ -220,6 +235,28 @@ void TypeTargetTransformPass::SetValidPlaces( valid_places_ = valid_places; } +void TypeTargetTransformPass::UpdateInstNode(Node* in, + SSAGraph* graph, + Node* inst_node, + std::string io_copy_output_name) { + // reset opdesc and update kernel information + UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), + in->AsArg().name, + io_copy_output_name); + auto original_selected_kernel = + std::move(inst_node->AsStmt().kernels().front()); + auto update_op_info = *inst_node->AsStmt().op_info(); + // ResetOp() will change the Stmt op_info_ value, + // after that the old op_info_ value will be nullified. + // So, we can't pass `*inst_node->AsStmt().op_info()` into ResetOp. + // `update_op_info` is the copy of `*inst_node->AsStmt().op_info(). + // Whenever update the op_info of a stmt, we should call its ResetOp(). 
+ inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places()); + inst_node->AsStmt().kernels().clear(); + inst_node->AsStmt().kernels().emplace_back( + std::move(original_selected_kernel)); +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/lite/core/mir/type_target_cast_pass.h b/lite/core/mir/type_target_cast_pass.h index 8a8cfaf9f9282cb477f7b9dd404d6f869333221b..e9a275882f7c2cb813c1c0b8add5cc4ca89b0c8b 100644 --- a/lite/core/mir/type_target_cast_pass.h +++ b/lite/core/mir/type_target_cast_pass.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "lite/core/mir/pass.h" #include "lite/core/op_registry.h" @@ -44,13 +45,17 @@ class TypeTargetTransformPass : public ProgramPass { public: void Apply(const std::unique_ptr& graph) override; - void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); + void ComplementInputs(SSAGraph* graph, + Node* inst_node, + Node* in, + std::unordered_map* copied_nodes); void AddIoCopyInst(const Type& from, const Type& to, Node* in, SSAGraph* graph, Node* inst_node, + std::unordered_map* copied_nodes, const std::vector& valid_places); void SetValidPlaces(const std::vector& valid_places); @@ -58,6 +63,11 @@ class TypeTargetTransformPass : public ProgramPass { const std::vector& valid_places() const { return valid_places_; } private: + void UpdateInstNode(Node* in, + SSAGraph* graph, + Node* inst_node, + std::string io_copy_output_name); + std::vector valid_places_; }; diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h index 3f5d161a56aafa7fd9d058fd404e65cb04572116..875bf23082a24cb6fcae878b46cc9dcdbb2b76f7 100644 --- a/lite/core/mir/variable_place_inference_pass.h +++ b/lite/core/mir/variable_place_inference_pass.h @@ -48,6 +48,10 @@ class VariablePlaceInferencePass : public DebugPass { void CheckAllArgumentTypeDetermined(SSAGraph* graph) { for (auto& node : graph->mutable_nodes()) { if (node.IsArg()) { + if (node.inlinks.size() == 0 && node.outlinks.size() == 0) { + // empty node + continue; + } CHECK(node.AsArg().type) << "node " << node.AsArg().name << " type not determined, " << &node; } @@ -129,6 +133,17 @@ class VariablePlaceInferencePass : public DebugPass { } else { x_in->AsArg().type = type; } + } else if (x_in->AsArg().type->target() == TARGET(kUnk) && + x_in->AsArg().type->precision() != PRECISION(kUnk) && + x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) { + // If is quantization, infer the Int8 type. + if (type->precision() == PRECISION(kInt8)) { + x_in->AsArg().type = type; + } else { + PrecisionType tmp_ptype = x_in->AsArg().type->precision(); + x_in->AsArg().type = LiteType::GetTensorTy( + type->target(), tmp_ptype, type->layout()); + } } } @@ -149,6 +164,17 @@ class VariablePlaceInferencePass : public DebugPass { } else { x_out->AsArg().type = type; } + } else if (x_out->AsArg().type->target() == TARGET(kUnk) && + x_out->AsArg().type->precision() != PRECISION(kUnk) && + x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) { + // If is quantization, infer the Int8 type. 
+        if (type->precision() == PRECISION(kInt8)) {
+          x_out->AsArg().type = type;
+        } else {
+          PrecisionType tmp_ptype = x_out->AsArg().type->precision();
+          x_out->AsArg().type = LiteType::GetTensorTy(
+              type->target(), tmp_ptype, type->layout());
+        }
       }
     }
   }
diff --git a/lite/core/mir/weight_quantization_preprocess_pass.cc b/lite/core/mir/weight_quantization_preprocess_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c7889a54903f2a1d194fb3eade0bd92670b36699
--- /dev/null
+++ b/lite/core/mir/weight_quantization_preprocess_pass.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/weight_quantization_preprocess_pass.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void WeightQuantizationPreprocessPass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  std::vector<std::string> weight_quantized_op = {"conv2d", "depthwise_conv2d"};
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    if (node->IsStmt() &&
+        std::find(weight_quantized_op.begin(),
+                  weight_quantized_op.end(),
+                  node->AsStmt().op_type()) != weight_quantized_op.end()) {
+      auto* scope = node->stmt()->op()->scope();
+      auto* op_desc = node->stmt()->mutable_op_info();
+      if (op_desc->HasAttr("quantize_weight_bits")) {
+        for (auto& input_name : op_desc->input_vars()) {
+          std::string scale_name = input_name + "_quant_scale";
+          if (op_desc->HasAttr(scale_name)) {
+            VLOG(5) << "op:" << op_desc->Type() << " input_name:" << input_name;
+            auto input_tensor =
+                scope->FindVar(input_name)->GetMutable<lite::Tensor>();
+            int weight_out_channel = static_cast<int>(input_tensor->dims()[0]);
+            auto input_scale = op_desc->GetAttr<std::vector<float>>(scale_name);
+            // the scale length equals the weight's output-channel count
+            std::vector<float> scale_list(weight_out_channel, input_scale[0]);
+            op_desc->SetAttr(scale_name, scale_list);
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(weight_quantization_preprocess_pass,
+                  paddle::lite::mir::WeightQuantizationPreprocessPass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/weight_quantization_preprocess_pass.h b/lite/core/mir/weight_quantization_preprocess_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..76a35c6b443c692ec08688abd4c10680be62b8af
--- /dev/null
+++ b/lite/core/mir/weight_quantization_preprocess_pass.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include "lite/core/mir/pass.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+/*
+ * If the model is quantized by WeightQuantization in PostTrainingQuantization,
+ * the data type of the weight in the quantized ops (conv2d, depthwise_conv2d)
+ * is int, and the scale is saved in the quantized ops.
+ * WeightQuantizationPreprocessPass obtains the scale value, expands the
+ * scale value to a list, and saves the list in the quantized ops.
+ */
+class WeightQuantizationPreprocessPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
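For reference, the scale expansion performed by the pass above reduces to the following standalone sketch (the helper name is ours, not part of the patch):

```cpp
#include <string>
#include <vector>

// Sketch of the per-channel scale expansion done by
// WeightQuantizationPreprocessPass: the single scale recorded by
// post-training weight quantization is broadcast to one entry per
// output channel of the conv weight.
std::vector<float> ExpandWeightScale(const std::vector<float>& input_scale,
                                     int weight_out_channel) {
  // The expanded list's length equals the weight's output-channel count.
  return std::vector<float>(weight_out_channel, input_scale[0]);
}
```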
diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc
index 3b8b350ad82f2cc1ce296b1ad74a6e322abec8ff..b49670eefb8b2c6aae30cb041de4d055a2b9964c 100644
--- a/lite/core/op_registry.cc
+++ b/lite/core/op_registry.cc
@@ -40,6 +40,18 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
       return Create<TARGET(target__),                                       \
                     PRECISION(precision__),                                 \
                     DATALAYOUT(kNCHW)>(op_type);                            \
+    case DATALAYOUT(kImageDefault):                                         \
+      return Create<TARGET(target__),                                       \
+                    PRECISION(precision__),                                 \
+                    DATALAYOUT(kImageDefault)>(op_type);                    \
+    case DATALAYOUT(kImageFolder):                                          \
+      return Create<TARGET(target__),                                       \
+                    PRECISION(precision__),                                 \
+                    DATALAYOUT(kImageFolder)>(op_type);                     \
+    case DATALAYOUT(kImageNW):                                              \
+      return Create<TARGET(target__),                                       \
+                    PRECISION(precision__),                                 \
+                    DATALAYOUT(kImageNW)>(op_type);                         \
     default:                                                                \
       LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
   }
@@ -54,6 +66,8 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
       CREATE_KERNEL1(target__, kFP16);   \
     case PRECISION(kAny):                \
       CREATE_KERNEL1(target__, kAny);    \
+    case PRECISION(kInt32):              \
+      CREATE_KERNEL1(target__, kInt32);  \
     case PRECISION(kInt64):              \
       CREATE_KERNEL1(target__, kInt64);  \
     default:                             \
@@ -86,6 +100,9 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
     case TARGET(kFPGA): {
       CREATE_KERNEL(kFPGA);
     } break;
+    case TARGET(kBM): {
+      CREATE_KERNEL(kBM);
+    } break;
     default:
       CHECK(false) << "not supported kernel target " << TargetToStr(target);
   }
@@ -115,6 +132,8 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kCUDA, kAny, kNCHW);
   INIT_FOR(kCUDA, kAny, kAny);
   INIT_FOR(kCUDA, kInt8, kNHWC);
+  INIT_FOR(kCUDA, kInt64, kNCHW);
+  INIT_FOR(kCUDA, kInt64, kNHWC);

   INIT_FOR(kHost, kFloat, kNCHW);
   INIT_FOR(kHost, kAny, kNCHW);
@@ -134,6 +153,7 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kARM, kInt8, kNCHW);
   INIT_FOR(kARM, kAny, kNCHW);
   INIT_FOR(kARM, kAny, kAny);
+  INIT_FOR(kARM, kInt32, kNCHW);

   INIT_FOR(kOpenCL, kFloat, kNCHW);
   INIT_FOR(kOpenCL, kFloat, kNHWC);
@@ -142,6 +162,17 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kOpenCL, kFloat, kAny);
   INIT_FOR(kOpenCL, kInt8, kNCHW);
   INIT_FOR(kOpenCL, kAny, kAny);
+  INIT_FOR(kOpenCL, kFP16, kNCHW);
+  INIT_FOR(kOpenCL, kFP16, kNHWC);
+  INIT_FOR(kOpenCL, kFP16, kImageDefault);
+  INIT_FOR(kOpenCL, kFP16, kImageFolder);
+  INIT_FOR(kOpenCL, kFP16, kImageNW);
+  INIT_FOR(kOpenCL, kFloat, kImageDefault);
+  INIT_FOR(kOpenCL, kFloat, kImageFolder);
+  INIT_FOR(kOpenCL, kFloat, kImageNW);
+  INIT_FOR(kOpenCL, kAny, kImageDefault);
+  INIT_FOR(kOpenCL, kAny, kImageFolder);
+  INIT_FOR(kOpenCL, kAny, kImageNW);

   INIT_FOR(kNPU, kFloat, kNCHW);
   INIT_FOR(kNPU, kInt8, kNCHW);
@@ -158,6 +189,11 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kFPGA, kFloat, kNHWC);
   INIT_FOR(kFPGA, kAny, kNHWC);
   INIT_FOR(kFPGA, kAny, kAny);
+
+  INIT_FOR(kBM, kFloat, kNCHW);
+  INIT_FOR(kBM, kInt8, kNCHW);
+  INIT_FOR(kBM, kAny, kNCHW);
+  INIT_FOR(kBM, kAny, kAny);
 #undef INIT_FOR
 }
diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h
index 1c67ee8f3dcafe30d9bda587d62233d0e715071e..a49682eea68240bfa178eb3d3351b8c7fb41048d 100644
--- a/lite/core/op_registry.h
+++ b/lite/core/op_registry.h
@@ -145,6 +145,15 @@ class KernelRegistry final {
                     KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
                     KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
                     KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
+                    KernelRegistryForTarget *,  //
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index a50ff3e6110a09851791190239358445141c8657..ddd94484ac4bb8d96d5c55300c985d21b44f1843 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <map>
 #include <memory>
+#include <set>
 #include <string>
 #include <vector>
 #include "lite/core/mir/generate_program_pass.h"
@@ -26,12 +27,6 @@
 #include "lite/core/program.h"
 #include "lite/core/types.h"
 #include "lite/model_parser/model_parser.h"
-#ifdef LITE_WITH_NPU
-#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
-#endif
-#ifdef LITE_WITH_XPU
-#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
-#endif

 namespace paddle {
 namespace lite {
@@ -50,21 +45,6 @@ class Optimizer {
     valid_places_ = valid_places;
     CHECK(!valid_places.empty()) << "At least one valid_place should be set";
     CHECK(!graph_) << "duplicate optimize found";
-    auto valid_places_has_target = [&](TargetType t) -> bool {
-      for (auto& p : valid_places) {
-        if (p.target == t) {
-          return true;
-        }
-      }
-      return false;
-    };
-    std::map<std::string, bool> lite_with_targets{
-        {"kOpenCL", valid_places_has_target(TARGET(kOpenCL))},
-        {"kNPU", valid_places_has_target(TARGET(kNPU))},
-        {"kXPU", valid_places_has_target(TARGET(kXPU))}};
-    VLOG(4) << "lite_with_targets['kOpenCL']:" << lite_with_targets["kOpenCL"];
-    VLOG(4) << "lite_with_targets['kNPU']:" << lite_with_targets["kNPU"];
-    VLOG(4) << "lite_with_targets['kXPU']:" << lite_with_targets["kXPU"];

     graph_.reset(new mir::SSAGraph);
     graph_->Build(program, valid_places);
@@ -75,19 +55,24 @@ class Optimizer {

     if (passes.empty()) {
       std::vector<std::string> passes_local{
-          {"lite_quant_dequant_fuse_pass",     //
-           "lite_conv_elementwise_fuse_pass",  // conv-elemwise-bn
-           "lite_conv_bn_fuse_pass",           //
-           "lite_conv_elementwise_fuse_pass",  // conv-bn-elemwise
+          {"lite_quant_dequant_fuse_pass",         //
+           "weight_quantization_preprocess_pass",  //
+           "lite_conv_elementwise_fuse_pass",      // conv-elemwise-bn
+           "lite_conv_bn_fuse_pass",               //
+           "lite_conv_elementwise_fuse_pass",      // conv-bn-elemwise
            // TODO(Superjomn) Refine the fusion related design to select fusion
            // kernels for devices automatically.
           "lite_conv_activation_fuse_pass",              //
+          "lite_var_conv_2d_activation_fuse_pass",       //
           "lite_fc_fuse_pass",                           //
           "lite_shuffle_channel_fuse_pass",              //
           "lite_transpose_softmax_transpose_fuse_pass",  //
           "lite_interpolate_fuse_pass",                  //
           "identity_scale_eliminate_pass",               //
-#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+          "elementwise_mul_constant_eliminate_pass",     //
+          "lite_sequence_pool_concat_fuse_pass",         //
+#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
+    (defined LITE_WITH_ARM)
           "lite_elementwise_add_activation_fuse_pass",  //
 #endif
           "static_kernel_pick_pass",  // pick original kernel from graph
@@ -122,13 +107,10 @@ class Optimizer {
           "argument_type_display_pass",

           "runtime_context_assign_pass",
-          "argument_type_display_pass"}};
-      if ((!lite_with_targets["kOpenCL"]) && (!lite_with_targets["kNPU"]) &&
-          (!lite_with_targets["kXPU"])) {
-        // TODO(ysh329): cause CL_INVALID_MEM_OBJECT when setArg in OpenCL
-        // kernel
-        passes_local.emplace_back("memory_optimize_pass");
-      }
+          "argument_type_display_pass",
+          "memory_optimize_pass",
+          "npu_subgraph_pass",
+          "xpu_subgraph_pass"}};
       RunPasses(passes_local);
     } else {
       RunPasses(passes);
@@ -140,40 +122,6 @@ class Optimizer {

   // Generate a new program based on the mir graph.
   std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
-#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
-    auto target_place = Place{
-#ifdef LITE_WITH_NPU
-        TARGET(kNPU),
-#endif
-#ifdef LITE_WITH_XPU
-        TARGET(kXPU),
-#endif
-        PRECISION(kFloat)};
-    if (std::find(valid_places_.begin(), valid_places_.end(), target_place) !=
-        valid_places_.end()) {
-#ifdef LITE_WITH_NPU
-      auto pass = mir::PassManager::Global()
-                      .LookUp<mir::subgraph::GenerateNPUProgramPass>(
-                          "generate_npu_program_pass");
-#endif
-
-#ifdef LITE_WITH_XPU
-      auto pass = mir::PassManager::Global()
-                      .LookUp<mir::subgraph::GenerateXPUProgramPass>(
-                          "generate_xpu_program_pass");
-#endif
-      try {
-        pass->Apply(graph_);
-        auto program = pass->GenProgram();
-        CHECK(exec_scope_);
-        program->set_exec_scope(exec_scope_);
-        return program;
-      } catch (...) {
-        LOG(WARNING) << "Build " << TargetToStr(target_place.target)
-                     << " program failed!";
-      }
-    }
-#endif
     auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
         "generate_program_pass");
     pass->Apply(graph_);
@@ -215,14 +163,16 @@ class Optimizer {
     for (auto& x : passes) {
       LOG(INFO) << "== Running pass: " << x;
       mir::Pass* pass = mir::PassManager::Global().LookUp(x);
-      CHECK(pass) << "Can not find pass: " << x;
-      bool matched = false;
+      if (!pass) {
+        LOG(INFO) << "   - Skip " << x << " because the pass isn't found.";
+        continue;
+      }
+      std::set<TargetType> targets;
       for (const auto& place : valid_places_) {
-        if (PassMatchesTarget(*pass, place.target)) {
-          matched = true;
-        }
+        targets.insert(place.target);
       }
-      matched = matched && PassMatchesKernels(*pass);
+      bool matched =
+          PassMatchesTarget(*pass, targets) && PassMatchesKernels(*pass);
       if (!matched) {
         LOG(INFO) << "   - Skip " << x
                   << " because the target or kernel does not match.";
diff --git a/lite/core/profile/CMakeLists.txt b/lite/core/profile/CMakeLists.txt
index 54a239024413834cb30c6e135c378d10480863e7..b7ddd810af46a25e2c331c2f0364a72f466dc636 100644
--- a/lite/core/profile/CMakeLists.txt
+++ b/lite/core/profile/CMakeLists.txt
@@ -5,4 +5,5 @@ endif()

 lite_cc_library(basic_profiler SRCS basic_profiler.cc DEPS gflags)
 lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler)
-
+lite_cc_library(lite_profiler SRCS profiler.cc DEPS context)
+lite_cc_test(test_lite_timer SRCS test_timer.cc DEPS lite_profiler)
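The Profiler added in the files below is driven roughly as follows (a host-timing sketch modeled on the new test in lite/core/profile/test_timer.cc; the op and kernel names are made-up examples):

```cpp
#include "lite/core/profile/profiler.h"
#include "lite/utils/cp_logging.h"

using namespace paddle::lite;

// One timer unit per kernel, keyed by an OpCharacter.
void ProfileOneKernel(KernelContext* ctx) {
  profile::Profiler profiler("demo");
  profile::OpCharacter ch;
  ch.target = TargetType::kHost;  // non-CUDA targets fall back to host timing
  ch.op_type = "conv2d";          // made-up example names
  ch.kernel_name = "conv2d/def";
  int idx = profiler.NewTimer(ch);

  profiler.StartTiming(profile::Type::kDispatch, idx, ctx);
  // ... launch the kernel ...
  profiler.StopTiming(profile::Type::kDispatch, idx, ctx);

  // Detailed table, excluding 0 warm-up laps.
  LOG(INFO) << profiler.Summary(profile::Type::kDispatch, false, 0);
}
```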
diff --git a/lite/core/profile/profiler.cc b/lite/core/profile/profiler.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f4d0e3c0afbe1f9df4e381a502e1800a3d58ba68
--- /dev/null
+++ b/lite/core/profile/profiler.cc
@@ -0,0 +1,151 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/profile/profiler.h"
+#include <map>
+#include <string>
+#include <utility>
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+namespace {
+// Lexicographic comparison so that std::map gets a strict weak ordering.
+auto op_comp = [](const OpCharacter& c1, const OpCharacter& c2) {
+  if (c1.target != c2.target) return c1.target < c2.target;
+  if (c1.op_type != c2.op_type) return c1.op_type < c2.op_type;
+  if (c1.kernel_name != c2.kernel_name) return c1.kernel_name < c2.kernel_name;
+  return c1.remark < c2.remark;
+};
+}  // namespace
+
+std::map<Type, std::string> TypeStr{
+    {Type::kUnk, "Unknown"},
+    {Type::kCreate, "Create"},
+    {Type::kDispatch, "Dispatch"},
+};
+
+StatisUnit::StatisUnit(const OpCharacter& ch) : character(ch) {
+  create_t.reset(new DeviceTimer<TargetType::kHost>());
+  if (ch.target == TargetType::kCUDA) {
+#ifdef LITE_WITH_CUDA
+    dispatch_t.reset(new DeviceTimer<TargetType::kCUDA>());
+#else
+    LOG(ERROR) << "The timer type specified as cuda is uninitialized, so the "
+                  "default x86 timer is used instead.";
+#endif
+  } else {
+    dispatch_t.reset(new DeviceTimer<TargetType::kHost>());
+  }
+}
+
+lite::profile::Timer* StatisUnit::Timer(Type type) {
+  if (type == Type::kCreate) {
+    return create_t.get();
+  } else if (type == Type::kDispatch) {
+    return dispatch_t.get();
+  }
+  LOG(FATAL) << "Timer cannot be returned for an unknown type.";
+  return nullptr;
+}
+
+int Profiler::NewTimer(const OpCharacter& ch) {
+  StatisUnit unit(ch);
+  units_.push_back(std::move(unit));
+  return units_.size() - 1;
+}
+
+void Profiler::StartTiming(Type type, const int index, KernelContext* ctx) {
+  CHECK_LT(index, units_.size())
+      << "The timer index in the profiler is out of range.";
+  units_[index].Timer(type)->Start(ctx);
+}
+
+float Profiler::StopTiming(Type type, const int index, KernelContext* ctx) {
+  CHECK_LT(index, units_.size())
+      << "The timer index in the profiler is out of range.";
+  return units_[index].Timer(type)->Stop(ctx);
+}
+
+std::string Profiler::Summary(Type type, bool concise, size_t w) {
+  using std::setw;
+  using std::left;
+  using std::fixed;
+  STL::stringstream ss;
+  std::string title;
+  // Title.
+  if (concise) {
+    ss << "Timing cycle = " << units_.front().Timer(type)->LapTimes().Size()
+       << std::endl;
+    ss << "===== Concise " << TypeStr.find(type)->second
+       << " Profiler Summary: " << name_ << ", Exclude " << w
+       << " warm-ups =====" << std::endl;
+  } else {
+    ss << "===== Detailed " << TypeStr.find(type)->second
+       << " Profiler Summary: " << name_ << ", Exclude " << w
+       << " warm-ups =====" << std::endl;
+  }
+  ss << setw(25) << left << "Operator Type"
+     << " " << setw(40) << left << "Kernel Name"
+     << " " << setw(12) << left << "Remark"
+     << " " << setw(12) << left << "Avg (ms)"
+     << " " << setw(12) << left << "Min (ms)"
+     << " " << setw(12) << left << "Max (ms)"
+     << " " << setw(12) << left << "Last (ms)" << std::endl;
+  // Profile information.
+  if (concise) {
+    std::map<OpCharacter, TimeInfo, decltype(op_comp)> summary(op_comp);
+    for (auto& unit : units_) {
+      auto ch = summary.find(unit.Character());
+      if (ch != summary.end()) {
+        ch->second.avg += unit.Timer(type)->LapTimes().Avg(w);
+        ch->second.min += unit.Timer(type)->LapTimes().Min(w);
+        ch->second.max += unit.Timer(type)->LapTimes().Max(w);
+      } else {
+        TimeInfo info({unit.Timer(type)->LapTimes().Avg(w),
+                       unit.Timer(type)->LapTimes().Min(w),
+                       unit.Timer(type)->LapTimes().Max(w)});
+        summary.insert({unit.Character(), info});
+      }
+    }
+    for (const auto& item : summary) {
+      // clang-format off
+      ss << setw(25) << left << fixed << item.first.op_type             \
+         << " " << setw(40) << left << fixed << item.first.kernel_name  \
+         << " " << setw(12) << left << fixed << item.first.remark       \
+         << " " << setw(12) << left << fixed << item.second.avg         \
+         << " " << setw(12) << left << fixed << item.second.min         \
+         << " " << setw(12) << left << fixed << item.second.max         \
+         << " " << std::endl;
+      // clang-format on
+    }
+  } else {
+    for (auto& unit : units_) {
+      const auto& times = unit.Timer(type)->LapTimes();
+      // clang-format off
+      ss << setw(25) << left << fixed << unit.Character().op_type            \
+         << " " << setw(40) << left << fixed << unit.Character().kernel_name \
+         << " " << setw(12) << left << fixed << unit.Character().remark      \
+         << " " << setw(12) << left << fixed << times.Avg(w)                 \
+         << " " << setw(12) << left << fixed << times.Min(w)                 \
+         << " " << setw(12) << left << fixed << times.Max(w)                 \
+         << " " << setw(12) << left << fixed << times.Last(w)                \
+         << std::endl;
+      // clang-format on
+    }
+  }
+  return ss.str();
+}
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/profile/profiler.h b/lite/core/profile/profiler.h
new file mode 100644
index 0000000000000000000000000000000000000000..3933e5ba01ebcb20420494a955cbc0e202879f76
--- /dev/null
+++ b/lite/core/profile/profiler.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/profile/timer.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+enum class Type {
+  kUnk = 0,
+  kCreate,
+  kDispatch,
+};
+
+extern std::map<Type, std::string> TypeStr;
+
+struct TimeInfo {
+  float avg;
+  float min;
+  float max;
+};
+
+struct OpCharacter {
+  TargetType target;
+  std::string op_type{std::string("N/A")};
+  std::string kernel_name{std::string("N/A")};
+  std::string remark{std::string("N/A")};
+};
+
+class StatisUnit final {
+ public:
+  explicit StatisUnit(const OpCharacter& ch);
+  lite::profile::Timer* Timer(Type type);
+  const OpCharacter& Character() const { return character; }
+
+ protected:
+  std::unique_ptr<lite::profile::Timer> create_t;
+  std::unique_ptr<lite::profile::Timer> dispatch_t;
+  OpCharacter character;
+};
+
+class Profiler final {
+ public:
+  Profiler() = default;
+  explicit Profiler(const std::string& name) : name_(name) {}
+  int NewTimer(const OpCharacter& ch);
+  void StartTiming(Type type, const int index, KernelContext* ctx);
+  float StopTiming(Type type, const int index, KernelContext* ctx);
+  std::string Summary(Type type, bool concise = true, size_t warm_up = 10);
+
+ private:
+  std::string name_{std::string("N/A")};
+  std::vector<StatisUnit> units_;
+};
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/profile/test_timer.cc b/lite/core/profile/test_timer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3841f0151890d377a87f4f5d4b6d069ee75b560e
--- /dev/null
+++ b/lite/core/profile/test_timer.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <chrono>  // NOLINT
+#include <thread>  // NOLINT
+#include "lite/core/context.h"
+#include "lite/core/profile/profiler.h"
+#include "lite/core/profile/timer.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+TEST(timer, real_latency) {
+  Timer timer;
+
+  timer.Start();
+  std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  timer.Stop();
+
+  timer.Start();
+  std::this_thread::sleep_for(std::chrono::milliseconds(50));
+  timer.Stop();
+
+  LOG(INFO) << "LapTimes().Avg() = " << timer.LapTimes().Avg();
+}
+
+#ifdef LITE_WITH_CUDA
+TEST(gpu_timer, real_latency) {
+  DeviceTimer<TargetType::kCUDA> timer;
+  KernelContext ctx;
+  cudaStream_t exec_stream;
+  cudaStreamCreate(&exec_stream);
+  (&ctx.As<CUDAContext>())->SetExecStream(exec_stream);
+
+  timer.Start(&ctx);
+  std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  timer.Stop(&ctx);
+
+  (&timer)->Start(&ctx);
+  std::this_thread::sleep_for(std::chrono::milliseconds(50));
+  timer.Stop(&ctx);
+
+  LOG(INFO) << "LapTimes().Avg() = " << timer.LapTimes().Avg();
+}
+
+TEST(profiler, real_latency) {
+  KernelContext ctx;
+  cudaStream_t exec_stream;
+  cudaStreamCreate(&exec_stream);
+  (&ctx.As<CUDAContext>())->SetExecStream(exec_stream);
+
+  Profiler profiler("name");
+  profile::OpCharacter ch;
+  ch.target = TargetType::kCUDA;
+  ch.op_type = "operator/1";
+  ch.kernel_name = "kernel/1";
+  int idx = profiler.NewTimer(ch);
+  profiler.StartTiming(Type::kDispatch, idx, &ctx);
+  std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  profiler.StopTiming(Type::kDispatch, idx, &ctx);
+  std::cout << profiler.Summary(Type::kDispatch);
+}
+#endif
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/profile/timer.h b/lite/core/profile/timer.h
new file mode 100644
index 0000000000000000000000000000000000000000..e9bb16bd27d5ec6fd21814c35db52b2467a12b51
--- /dev/null
+++ b/lite/core/profile/timer.h
@@ -0,0 +1,140 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <chrono>  // NOLINT
+#include <vector>
+#ifdef LITE_WITH_CUDA
+#include "lite/backends/cuda/cuda_utils.h"
+#endif
+#include "lite/core/context.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+template <typename T>
+class TimeList {
+ public:
+  void Clear() { laps_t_.clear(); }
+  void Add(T t) { laps_t_.push_back(t); }
+  T Last(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return laps_t_.back();
+  }
+  T Max(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return *std::max_element((laps_t_.begin() + offset), laps_t_.end());
+  }
+  T Min(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return *std::min_element((laps_t_.begin() + offset), laps_t_.end());
+  }
+  T Sum(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return std::accumulate((laps_t_.begin() + offset), laps_t_.end(), 0.0);
+  }
+  size_t Size(size_t offset = 0) const {
+    size_t size = (laps_t_.size() <= offset) ? 0 : (laps_t_.size() - offset);
+    return size;
+  }
+  T Avg(size_t offset = 0) const {
+    if (!Size(offset)) {
+      return 0;
+    }
+    return Sum(offset) / Size(offset);
+  }
+  const std::vector<T>& Raw() const { return laps_t_; }
+
+ private:
+  std::vector<T> laps_t_;
+};
+
+class Timer {
+ public:
+  Timer() = default;
+  virtual ~Timer() = default;
+
+  void Reset() { laps_t_.Clear(); }
+  void Start() { t_start_ = std::chrono::system_clock::now(); }
+  float Stop() {
+    t_stop_ = std::chrono::system_clock::now();
+    auto ts = std::chrono::duration_cast<std::chrono::microseconds>(t_stop_ -
+                                                                    t_start_);
+    float elapse_ms = 1000.f * static_cast<float>(ts.count()) *
+                      std::chrono::microseconds::period::num /
+                      std::chrono::microseconds::period::den;
+    this->laps_t_.Add(elapse_ms);
+    return elapse_ms;
+  }
+  virtual void Start(KernelContext* ctx) { return Start(); }
+  virtual float Stop(KernelContext* ctx) { return Stop(); }
+  float AvgLapTimeMs() const { return laps_t_.Avg(); }
+  const TimeList<float>& LapTimes() const { return laps_t_; }
+
+ protected:
+  TimeList<float> laps_t_;
+
+ private:
+  std::chrono::time_point<std::chrono::system_clock> t_start_, t_stop_;
+};
+
+template <TargetType Target>
+class DeviceTimer final : public Timer {};
+
+#ifdef LITE_WITH_CUDA
+template <>
+class DeviceTimer<TargetType::kCUDA> final : public Timer {
+ public:
+  DeviceTimer() {
+    CUDA_CALL(cudaEventCreate(&e_start_));
+    CUDA_CALL(cudaEventCreate(&e_stop_));
+  }
+  ~DeviceTimer() {
+    CUDA_CALL(cudaEventDestroy(e_start_));
+    CUDA_CALL(cudaEventDestroy(e_stop_));
+  }
+  void Start(KernelContext* ctx) {
+    cudaStream_t stream;
+    stream = ctx->As<CUDAContext>().exec_stream();
+    CUDA_CALL(cudaEventRecord(e_start_, stream));
+  }
+  float Stop(KernelContext* ctx) {
+    cudaStream_t stream;
+    stream = ctx->As<CUDAContext>().exec_stream();
+    CUDA_CALL(cudaEventRecord(e_stop_, stream));
+    CUDA_CALL(cudaEventSynchronize(e_stop_));
+    float elapse_ms = 1.f;
+    CUDA_CALL(cudaEventElapsedTime(&elapse_ms, e_start_, e_stop_));
+    this->laps_t_.Add(elapse_ms);
+    return elapse_ms;
+  }
+
+ private:
+  cudaEvent_t e_start_, e_stop_;
+};
+#endif
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
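A quick worked example of the TimeList offset semantics above (the `offset` counts leading warm-up laps that are excluded from the statistics; the helper function is ours):

```cpp
#include <cassert>
#include "lite/core/profile/timer.h"

// Worked example: offset = number of warm-up laps to ignore.
void TimeListOffsetExample() {
  paddle::lite::profile::TimeList<float> laps;
  laps.Add(10.f);  // warm-up lap
  laps.Add(2.f);
  laps.Add(4.f);
  assert(laps.Size(1) == 2);    // the warm-up lap is excluded
  assert(laps.Avg(1) == 3.f);   // (2 + 4) / 2
  assert(laps.Max(0) == 10.f);  // offset 0 keeps every lap
}
```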
diff --git a/lite/core/program.cc b/lite/core/program.cc
index b60f279c0fc74904477a080579a799f601e359b0..0895643a6adde0095f9d2892c41f263eedd4284f 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -17,6 +17,8 @@
 #include "lite/model_parser/cpp/block_desc.h"
 #include "lite/model_parser/cpp/op_desc.h"
 #include "lite/model_parser/cpp/var_desc.h"
+#include "lite/operators/conditional_block_op.h"
+#include "lite/operators/subgraph_op.h"
 #include "lite/operators/while_op.h"
 #ifdef LITE_WITH_PROFILE
 #include "lite/core/profile/precision_profiler.h"
@@ -30,10 +32,32 @@ void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
   // NOTE: RuntimeProgram does not have all the meta info, so saving the model
   // just updates the origin model
   CHECK(desc->BlocksSize());
-  auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
-  main_block.ClearOps();
+  auto main_block = desc->GetBlock<cpp::BlockDesc>(0);
+  main_block->ClearOps();
   for (auto& node : instructions_) {
-    auto* op = main_block.AddOp<cpp::OpDesc>();
+    auto op_type = node.op()->op_info()->Type();
+    if (op_type == "subgraph") {
+      auto subgraph_op = const_cast<operators::SubgraphOp*>(
+          static_cast<const operators::SubgraphOp*>(node.op()));
+      int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
+      if (sub_block_idx < 0) {
+        // It's a new subgraph op when its sub_block_idx < 0. Add its
+        // subblock desc to the program desc, then update its sub_block_idx
+        // to the index of that block desc in the program desc.
+        sub_block_idx = desc->BlocksSize();
+        auto sub_block_desc = subgraph_op->GetSubBlock();
+        CHECK(sub_block_desc);
+        auto new_block_desc = desc->AddBlock<cpp::BlockDesc>();
+        *new_block_desc = *sub_block_desc;
+        delete sub_block_desc;
+        subgraph_op->mutable_op_info()->SetAttr<int32_t>("sub_block",
+                                                         sub_block_idx);
+        subgraph_op->SetSubBlock(new_block_desc);
+        // Update main block desc after a new subblock desc is added
+        main_block = desc->GetBlock<cpp::BlockDesc>(0);
+      }
+    }
+    auto op = main_block->AddOp<cpp::OpDesc>();
     *op = *node.op()->op_info();
     op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
   }
@@ -113,15 +137,21 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {

 void RuntimeProgram::Run() {
   for (auto& inst : instructions_) {
-    std::string op_type = inst.op()->op_info()->Type();
-    if (op_type == "feed" || op_type == "fetch") continue;
+#ifndef LITE_WITH_FPGA
+    if (inst.is_feed_fetch_op()) continue;
+#endif
     inst.Run();
 #ifdef LITE_WITH_PROFILE
 #ifdef LITE_WITH_PRECISION_PROFILE
+#ifndef LITE_WITH_FPGA
     LITE_PRECISION_PROFILE(inst)
+#endif
 #endif  // LITE_WITH_PRECISION_PROFILE
 #endif  // LITE_WITH_PROFILE
   }
+#ifdef LITE_WITH_PROFILE
+  LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 0);
+#endif  // LITE_WITH_PROFILE
 }

 void Program::Build(const cpp::ProgramDesc& prog) {
@@ -138,12 +168,26 @@ void Program::Build(const cpp::ProgramDesc& prog) {
     VLOG(4) << "create Op [" << op_type << "]";
     auto op = LiteOpRegistry::Global().Create(op_type);
     CHECK(op) << "no Op found for " << op_type;
-    if (op_type == "while") {
+    if (op_type == "while" || op_type == "conditional_block" ||
+        op_type == "subgraph") {
       auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
-      auto sub_block =
+      CHECK(sub_block_idx >= 0 && sub_block_idx < program.BlocksSize())
+          << "Invalid attribute sub_block(" << sub_block_idx << ") for "
+          << op_type;
+      auto sub_block_desc =
           const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
               sub_block_idx);
-      static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(sub_block);
+      CHECK(sub_block_desc);
+      if (op_type == "while") {
+        static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(
+            sub_block_desc);
+      } else if (op_type == "conditional_block") {
+        static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock(
+            sub_block_desc);
+      } else if (op_type == "subgraph") {
+        static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(
+            sub_block_desc);
+      }
     }
     ops_.emplace_back(std::move(op));
     ops_.back()->Attach(op_desc, exec_scope_);
@@ -159,6 +203,27 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
   tmp_vars_.push_back("feed");
   tmp_vars_.push_back("fetch");

+  auto VarPrecision2KernelPrecision =
+      [](const lite::VarDescAPI::Type& type) -> PrecisionType {
+    switch (type) {
+      case lite::VarDescAPI::Type::FP32:
+        return PRECISION(kFloat);
+      case lite::VarDescAPI::Type::FP16:
+        return PRECISION(kFP16);
+      case lite::VarDescAPI::Type::INT8:
+        return PRECISION(kInt8);
+      case lite::VarDescAPI::Type::INT16:
+        return PRECISION(kInt16);
+      case lite::VarDescAPI::Type::INT32:
+        return PRECISION(kInt32);
+      case lite::VarDescAPI::Type::INT64:
+        return PRECISION(kInt64);
+      default:
+        // LOG(FATAL) << "not supported type: " << static_cast<int>(type);
+        return PRECISION(kUnk);
+    }
+  };
+
   auto program = prog;
   CHECK(program.BlocksSize());
   for (size_t b = 0; b < program.BlocksSize(); ++b) {
@@ -166,7 +231,16 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
     for (size_t i = 0; i < main_block.VarsSize(); ++i) {
       auto& var_desc = *main_block.GetVar(i);
       if (!var_desc.Persistable()) {
+        if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR &&
+            VarPrecision2KernelPrecision(var_desc.GetDataType()) !=
+                PRECISION(kUnk)) {
+          var_data_type_[var_desc.Name()] =
+              VarPrecision2KernelPrecision(var_desc.GetDataType());
+        }
         tmp_vars_.push_back(var_desc.Name());
+        VLOG(4) << "var name: " << var_desc.Name() << " type is "
+                << static_cast<int>(var_desc.GetType()) << " data type is "
+                << static_cast<int>(var_desc.GetDataType());
         exec_scope_->Var(var_desc.Name());
         if (b > 0) {
           VLOG(4) << "var: " << var_desc.Name();
@@ -181,13 +255,16 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
 }

 void Instruction::Run() {
+#ifdef LITE_WITH_PROFILE
+  CHECK(profiler_) << "Profiler pointer of kernel can not be nullptr. "
+                      "When LITE_WITH_PROFILE is defined, please set a "
+                      "Profiler for Instruction.";
+  profiler_->StartTiming(
+      profile::Type::kCreate, profile_id_, kernel_->mutable_context());
+#endif
   CHECK(op_) << "op null";
   CHECK(kernel_) << "kernel null";
-#ifdef LITE_WITH_PROFILE
-  if (profile_id_ >= 0) {
-    profile::ProfileBlock x(profile_id_, "instruction");
-  }
-#endif  // LITE_WITH_PROFILE
+
   if (first_epoch_) {
     first_epoch_ = false;
    CHECK(op_->CheckShape());
@@ -196,14 +273,8 @@ void Instruction::Run() {
   if (op_->run_once() && has_run_) {
     return;
   }
-#ifndef LITE_SHUTDOWN_LOG
-  VLOG(4) << "kernel launch";
-#endif
+
   op_->InferShape();
-#ifndef LITE_SHUTDOWN_LOG
-  VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
-          << TargetToStr(kernel_->target());
-#endif
   kernel_->Launch();
   has_run_ = true;
 }
diff --git a/lite/core/program.h b/lite/core/program.h
index 7a6700da61f7ba9f35491613d7733b4b637b8ff0..c845a17c52c0c565e339a13e093f3e8f59e8d4a7 100644
--- a/lite/core/program.h
+++ b/lite/core/program.h
@@ -16,15 +16,13 @@
 #include <list>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "lite/core/kernel.h"
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
 #include "lite/model_parser/cpp/program_desc.h"
-#ifdef LITE_WITH_PROFILE
-#include "lite/core/profile/basic_profiler.h"
-#endif  // LITE_WITH_PROFILE

 namespace paddle {
 namespace lite {
@@ -66,6 +64,10 @@ struct Program {
   lite::Scope* exec_scope() { return exec_scope_; }
   lite::Scope* scope() { return scope_.get(); }

+  const std::unordered_map<std::string, PrecisionType>& var_data_type() const {
+    return var_data_type_;
+  }
+
  private:
   // Build from a program and scope.
   void Build(const cpp::ProgramDesc& program);
@@ -73,6 +75,7 @@ struct Program {
   void PrepareWorkspace(const cpp::ProgramDesc& program);

  private:
+  std::unordered_map<std::string, PrecisionType> var_data_type_;
   std::list<std::string> tmp_vars_;
   std::list<std::string> weights_;
   std::list<std::shared_ptr<OpLite>> ops_;
@@ -88,20 +91,10 @@ struct Instruction {
   Instruction(const std::shared_ptr<OpLite>& op,
               std::unique_ptr<KernelBase>&& kernel)
       : op_(op), kernel_(std::move(kernel)) {
-#ifdef LITE_WITH_PROFILE
-    if (op_->Type() != "feed" && op_->Type() != "fetch") {
-      profile_id_ = profile::BasicProfiler<profile::Inst>::Global()
-                        .NewRcd(kernel_->SerializedKernelType())
-                        .id();
-      kernel_->SetProfileID(profile_id_);
-      // Set profile custom info
-      auto& profiler =
-          *profile::BasicProfiler<profile::Inst>::Global().mutable_record(
-              profile_id_);
-      profiler.SetCustomInfo("op_type", op_->Type());
-      profiler.SetCustomInfo("op_info", op_->SerializedOpInfo());
+    std::string op_type = op->Type();
+    if (op_type == "feed" || op_type == "fetch") {
+      is_feed_fetch_op_ = true;
     }
-#endif  // LITE_WITH_PROFILE
   }

   // Run the instruction.
@@ -113,14 +106,31 @@
   const KernelBase* kernel() const { return kernel_.get(); }
   KernelBase* mutable_kernel() { return kernel_.get(); }

+  bool is_feed_fetch_op() const { return is_feed_fetch_op_; }
+
+#ifdef LITE_WITH_PROFILE
+  void set_profiler(profile::Profiler* profiler) {
+    profiler_ = profiler;
+    if (op_->Type() != "feed" && op_->Type() != "fetch") {
+      profile::OpCharacter ch;
+      ch.target = kernel()->target();
+      ch.op_type = op_->Type();
+      ch.kernel_name = kernel()->name();
+      profile_id_ = profiler->NewTimer(ch);
+      kernel_->SetProfiler(profiler_, profile_id_);
+    }
+  }
+#endif
+
  private:
   std::shared_ptr<OpLite> op_;
   std::unique_ptr<KernelBase> kernel_;
+  bool is_feed_fetch_op_{false};
   bool first_epoch_{true};
   bool has_run_{false};

 #ifdef LITE_WITH_PROFILE
-  // for profiler
+  profile::Profiler* profiler_;
   int profile_id_{-1};
 #endif  // LITE_WITH_PROFILE
 };
@@ -135,6 +145,15 @@ class LITE_API RuntimeProgram {
     if (instructions_.empty()) {
       LOG(FATAL) << "no instructions";
     }
+#ifdef LITE_WITH_PROFILE
+    set_profiler();
+#endif
+  }
+  ~RuntimeProgram() {
+#ifdef LITE_WITH_PROFILE
+    LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kCreate);
+    LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch);
+#endif  // LITE_WITH_PROFILE
   }

   void Run();
@@ -159,6 +178,15 @@ class LITE_API RuntimeProgram {
   RuntimeProgram(const RuntimeProgram&) = delete;
   std::vector<Instruction> instructions_;
   lite::Scope* exec_scope_{};
+
+#ifdef LITE_WITH_PROFILE
+  profile::Profiler profiler_;
+  void set_profiler() {
+    for (auto i = instructions_.begin(); i != instructions_.end(); ++i) {
+      i->set_profiler(&profiler_);
+    }
+  }
+#endif
 };

 }  // namespace lite
diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc
index 1c7db871c7b525d6e4944fd0d669e81bcaff7f2a..38a6be6767eae62f9d91c9c11811bc49639331bf 100644
--- a/lite/core/tensor.cc
+++ b/lite/core/tensor.cc
@@ -25,21 +25,17 @@ using value_type = int64_t;

 value_type DDimLite::production() const {
   value_type res = 1;
-  for (size_t i = 0; i < this->size(); i++) {
-    res *= (*this)[i];
+  for (size_t i = 0; i < data_.size(); i++) {
+    res *= data_[i];
   }
   return res;
 }

 value_type DDimLite::count(int start, int end) const {
-  if (start < 0) {
-    start = 0;
-  }
-  if (end > size()) {
-    end = size();
-  }
+  start = std::max(start, 0);
+  end = std::min(end, static_cast<int>(data_.size()));
   if (end < start) {
-    end = start;
+    return 0;
   }
   value_type sum = 1;
   for (auto i = start; i < end; ++i) {
@@ -49,11 +45,13 @@ value_type DDimLite::count(int start, int end) const {
 }

 DDimLite DDimLite::Slice(int start, int end) const {
-  std::vector<value_type> vec;
+  start = std::max(start, 0);
+  end = std::min(end, static_cast<int>(data_.size()));
+  std::vector<value_type> new_dim(end - start);
   for (int i = start; i < end; i++) {
-    vec.push_back((*this)[i]);
+    new_dim[i - start] = data_[i];
   }
-  return DDimLite(vec);
+  return DDim(new_dim);
 }

 std::string DDimLite::repr() const {
@@ -104,6 +102,12 @@ const cl::Image2D *TensorLite::data<float, cl::Image2D>() const {
   if (nullptr == buffer_->data()) return nullptr;
   return static_cast<const cl::Image2D *>(buffer_->data());
 }
+
+template <>  // use int16_t to represent half float
+const cl::Image2D *TensorLite::data<int16_t, cl::Image2D>() const {
+  if (nullptr == buffer_->data()) return nullptr;
+  return static_cast<const cl::Image2D *>(buffer_->data());
+}
 #endif

 }  // namespace lite
diff --git a/lite/core/tensor.h b/lite/core/tensor.h
index 8c4fe1604a517332e52b243404828e81af26f419..04e540002b553a0e0f7db0144fd970bdb6a4d9ed 100644
--- a/lite/core/tensor.h
+++ b/lite/core/tensor.h
@@ -85,7 +85,11 @@ class DDimLite {
   }

   friend bool operator!=(const DDimLite &a, const DDimLite &b) {
-    return !(a == b);
+    if (a.size() != b.size()) return true;
+    for (size_t i = 0; i < a.size(); i++) {
+      if (a[i] != b[i]) return true;
+    }
+    return false;
   }

  private:
@@ -118,7 +122,7 @@ class TensorLite {
   }

   void Resize(const DDimLite &ddim) { dims_ = ddim; }
-  void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
+  void Resize(const std::vector<int64_t> &x) { dims_.ConstructFrom(x); }

   const DDimLite &dims() const { return dims_; }
   int64_t numel() const { return dims_.production(); }
@@ -139,6 +143,7 @@ class TensorLite {
   // For other devices, T and R may be the same type.
   template <typename T, typename R = T>
   R *mutable_data() {
+    precision_ = lite_api::PrecisionTypeTrait<T>::Type();
     memory_size_ = dims_.production() * sizeof(T);
     buffer_->ResetLazy(target_, memory_size_);
     return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) +
@@ -147,9 +152,11 @@ class TensorLite {

 #ifdef LITE_WITH_OPENCL
   template <typename T, typename R = T>
-  R *mutable_data(const size_t img_w, const size_t img_h) {
+  R *mutable_data(const size_t img_w,
+                  const size_t img_h,
+                  void *host_ptr = nullptr) {
     target_ = TARGET(kOpenCL);
-    buffer_->ResetLazyImage2D(target_, img_w, img_h);
+    buffer_->ResetLazyImage2D(target_, img_w, img_h, host_ptr);
     return static_cast<R *>(buffer_->data());
   }
 #endif
@@ -161,10 +168,7 @@ class TensorLite {
   template <typename T, typename R = T>
   R *mutable_data(TargetType target) {
     target_ = target;
-    memory_size_ = dims_.production() * sizeof(T);
-    buffer_->ResetLazy(target, memory_size());
-    return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) +
-                                 offset_);
+    return mutable_data<T, R>();
   }
   void *mutable_data(size_t memory_size);
   void *mutable_data(TargetType target, size_t memory_size);
@@ -174,6 +178,10 @@ class TensorLite {
                                  (static_cast<char *>(buffer_->data()) + offset_));
   }

+  void clear() {
+    buffer_->Free();
+    offset_ = 0;
+  }
   size_t data_size() const { return this->dims().production(); }

   size_t memory_size() const { return memory_size_; }
@@ -251,6 +259,9 @@ bool TensorCompareWith(const TensorT &a, const TensorT &b) {
 #ifdef LITE_WITH_OPENCL
 template <>
 const cl::Image2D *TensorLite::data<float, cl::Image2D>() const;
+
+template <>  // use int16_t to represent half float
+const cl::Image2D *TensorLite::data<int16_t, cl::Image2D>() const;
 #endif

 }  // namespace lite
diff --git a/lite/core/version.h.in b/lite/core/version.h.in
index 3082adc5abecb20f5ce19032177fc7cdb75299ff..d34c32073b852a50b5d26984ed4812ac4f38a870 100644
--- a/lite/core/version.h.in
+++ b/lite/core/version.h.in
@@ -42,7 +42,7 @@ static std::string version() {

   std::string tag = paddlelite_tag();
   if (tag.empty()) {
-    ss << paddlelite_branch() << "(" << paddlelite_commit() << ")";
+    ss << paddlelite_commit();
   } else {
     ss << tag;
   }
diff --git a/lite/demo/cxx/Makefile.def b/lite/demo/cxx/Makefile.def
index 1b5da970e8fa9b2793f7a4982d5ed22ed21e79fd..800331035323735c01b04940e70fd034ede51c84 100644
--- a/lite/demo/cxx/Makefile.def
+++ b/lite/demo/cxx/Makefile.def
@@ -1,35 +1,43 @@
-CXX_DEFINES = -DARM_WITH_OMP -DHPPL_STUB_FUNC -DLITE_WITH_ARM -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK \
-              -DLITE_WITH_LINUX -DPADDLE_DISABLE_PROFILER -DPADDLE_NO_PYTHON -DPADDLE_WITH_TESTING
-LDFLAGS = -latomic -pthread -ldl
+# get the name of the current operating system: Linux or Darwin
+SYSTEM=$(shell "uname")

-SYSROOT_COMPLILE = --sysroot=/opt/android-ndk-r17c/sysroot
+CXX_DEFINES = -DARM_WITH_OMP -DHPPL_STUB_FUNC -DLITE_WITH_ARM -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK \
+              -DLITE_WITH_LINUX -DPADDLE_DISABLE_PROFILER -DPADDLE_NO_PYTHON -DPADDLE_WITH_TESTING
+LDFLAGS = -latomic -pthread -ldl -llog -lz

-THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a
+SYSROOT_COMPLILE = --sysroot=$(NDK_ROOT)/sysroot
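+# Usage sketch (hypothetical NDK location): point NDK_ROOT at an
+# android-ndk-r17c style install before invoking make, e.g.
+#   export NDK_ROOT=/opt/android-ndk-r17c
+#   make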
-SYSTEM_INCLUDES = -I/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/include \ - -I/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++abi/include \ - -I/opt/android-ndk-r17c/sources/android/support/include \ - -I/opt/android-ndk-r17c/sysroot/usr/include \ +SYSTEM_INCLUDES = -I$(NDK_ROOT)/sources/cxx-stl/llvm-libc++/include \ + -I$(NDK_ROOT)/sources/cxx-stl/llvm-libc++abi/include \ + -I$(NDK_ROOT)/sources/android/support/include \ + -I$(NDK_ROOT)/sysroot/usr/include \ -THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include ifeq ($(ARM_ABI), arm8) - CC = /opt/android-ndk-r17c/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/bin/aarch64-linux-android-g++ - CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE - CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--gc-sections - SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-24/arch-arm64 - SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++_static.a \ - /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++abi.a - INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/aarch64-linux-android $(THIRD_PARTY_INCLUDES) + ifeq ($(SYSTEM), Linux) + CC = $(NDK_ROOT)/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/bin/aarch64-linux-android-g++ + else + CC = $(NDK_ROOT)/toolchains/aarch64-linux-android-4.9/prebuilt/darwin-x86_64/bin/aarch64-linux-android-g++ + endif + CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE + CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--gc-sections + SYSROOT_LINK = --sysroot=$(NDK_ROOT)/platforms/android-24/arch-arm64 + SYSTEM_LIBS = $(NDK_ROOT)/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++_static.a \ + $(NDK_ROOT)/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++abi.a + INCLUDES = $(SYSTEM_INCLUDES) -I$(NDK_ROOT)/sysroot/usr/include/aarch64-linux-android else - CC = /opt/android-ndk-r17c/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin/arm-linux-androideabi-g++ + ifeq ($(SYSTEM), Linux) + CC = $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin/arm-linux-androideabi-g++ + else + CC = $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/darwin-x86_64/bin/arm-linux-androideabi-g++ + endif CXX_FLAGS = -march=armv7-a -mthumb -mfpu=neon -mfloat-abi=softfp -funwind-tables -no-canonical-prefixes \ - -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE + -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--fix-cortex-a8 -Wl,--gc-sections -Wl,-z,nocopyreloc - SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-23/arch-arm - SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++_static.a \ - /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++abi.a \ - /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libandroid_support.a \ - /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libunwind.a - INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/arm-linux-androideabi $(THIRD_PARTY_INCLUDES) + SYSROOT_LINK = --sysroot=$(NDK_ROOT)/platforms/android-23/arch-arm + SYSTEM_LIBS = $(NDK_ROOT)/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++_static.a \ + $(NDK_ROOT)/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++abi.a \ + 
$(NDK_ROOT)/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libandroid_support.a \
+                  $(NDK_ROOT)/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libunwind.a
+    INCLUDES = $(SYSTEM_INCLUDES) -I$(NDK_ROOT)/sysroot/usr/include/arm-linux-androideabi
 endif
diff --git a/lite/demo/cxx/README.md b/lite/demo/cxx/README.md
index ec72c044e3fd08bd775b23c373945c5bb5743d1d..447bcbaff018d15a1bc3075c1153f724672f40a8 100644
--- a/lite/demo/cxx/README.md
+++ b/lite/demo/cxx/README.md
@@ -1,42 +1,181 @@
 # C++ Demo
-1. Build the docker image with `lite/tools/Dockerfile.mobile`
-2. Run and enter the docker environment, then download the demo package with `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/r0.1/inference_lite_lib.android.armv8.tar.gz` (for the armv7 demo, use `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/r0.1/inference_lite_lib.android.armv7.tar.gz`).
-3. Extract the downloaded archive: `tar zxvf inference_lite_lib.android.armv8.tar.gz`
-4. Prepare the emulator environment with the following commands
+
+> You are welcome to join the official PaddleLite QQ group (696965088), where the team answers questions and concerns.
+
+1. Prerequisites
+    - a computer able to compile PaddleLite
+    - an Android phone with an armv7 or armv8 CPU
+
+2. Face detection and mask-wearing classification demo
+
+Prepare the build environment by following [Compiling from source](https://paddlepaddle.github.io/Paddle-Lite/v2.2.0/source_compile/).
+
+Run the following commands to download the PaddleLite code.
+```shell
+git clone https://github.com/PaddlePaddle/Paddle-Lite.git
+cd Paddle-Lite
+```
+
+From the PaddleLite root directory, build the inference library.
 ```shell
-# armv8
-adb kill-server
-adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
-echo n | avdmanager create avd -f -n paddle-armv8 -k "system-images;android-24;google_apis;arm64-v8a"
-echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv8 -noaudio -no-window -gpu off -port 5554 &
-sleep 1m
+./lite/tools/build.sh \
+  --arm_os=android \
+  --arm_abi=armv8 \
+  --arm_lang=gcc \
+  --android_stl=c++_static \
+  --build_extra=ON \
+  --shutdown_log=OFF \
+  tiny_publish
 ```
+
+Enter the build directory, download the model and image archive, and build the executable.
 ```shell
-# armv7
-adb kill-server
-adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
-echo n | avdmanager create avd -f -n paddle-armv7 -k "system-images;android-24;google_apis;armeabi-v7a"
-echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv7 -noaudio -no-window -gpu off -port 5554 &
-sleep 1m
+cd build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/demo/cxx/mask_detection
+wget https://paddle-inference-dist.bj.bcebos.com/mask_detection.tar.gz
+tar zxvf mask_detection.tar.gz
+make
+```
+
+Alternatively, the face detection and mask classification models can be downloaded through PaddleHub.
-5. Prepare the model, then build and run the full-API demo
+```
+# After installing paddlehub, run the following code in Python
+import paddlehub as hub
+pyramidbox_lite_mobile_mask = hub.Module(name="pyramidbox_lite_mobile_mask")
+# Save the models into the test_program folder
+pyramidbox_lite_mobile_mask.processor.save_inference_model(dirname="test_program")
+# The commands above produce the face detection and mask-wearing classification models, stored in pyramidbox_lite and mask_detector respectively. In each folder, __model__ is the topology file and the __param__ file holds the weights.
+# PaddleHub provides inference models, so they must be converted with PaddleLite's model_optimize_tool; see the [model conversion docs](https://paddlepaddle.github.io/Paddle-Lite/v2.2.0/model_optimize_tool/).
+```
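+
+For reference, that conversion step looks roughly like this (flag names per the model_optimize_tool documentation; the input/output paths are hypothetical):
+```shell
+./model_optimize_tool \
+  --model_file=./test_program/mask_detector/__model__ \
+  --param_file=./test_program/mask_detector/__param__ \
+  --optimize_out_type=naive_buffer \
+  --valid_targets=arm \
+  --optimize_out=./mask_detector_opt
+```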
+
+Connect the Android phone to the computer, then push the executable, test image, model files, and inference library to the phone.
+```
+adb push mask_detection /data/local/tmp/
+adb push test.jpg /data/local/tmp/
+adb push face_detection /data/local/tmp
+adb push mask_classification /data/local/tmp
+adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/mask_detection
+```
+
+On the phone, run the demo.
+```
+adb shell
+cd /data/local/tmp
+export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH
+./mask_detection face_detection mask_classification test.jpg
+```
+
+Back on the computer, pull the result and inspect the output image.
+```
+adb pull /data/local/tmp/test_mask_detection_result.jpg ./
+```
+
+![test_mask_detection_result](https://user-images.githubusercontent.com/7383104/74279176-6200cd00-4d55-11ea-9fc0-83cfc2b3b37d.jpg)
+
+3. Build and run the full-API demo (note: this demo is not available when the build mode is tiny_publish)
 ```shell
 cd inference_lite_lib.android.armv8/demo/cxx/mobile_full
 wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz
 tar zxvf mobilenet_v1.tar.gz
 make
-adb -s emulator-5554 push mobilenet_v1 /data/local/tmp/
-adb -s emulator-5554 push mobilenetv1_full_api /data/local/tmp/
-adb -s emulator-5554 shell chmod +x /data/local/tmp/mobilenetv1_full_api
-adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_full_api --model_dir=/data/local/tmp/mobilenet_v1 --optimized_model_dir=/data/local/tmp/mobilenet_v1.opt"
+adb push mobilenet_v1 /data/local/tmp/
+adb push mobilenetv1_full_api /data/local/tmp/
+adb push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/mobilenetv1_full_api
+adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+/data/local/tmp/mobilenetv1_full_api --model_dir=/data/local/tmp/mobilenet_v1 --optimized_model_dir=/data/local/tmp/mobilenet_v1.opt"
 ```
 On success, the console prints the predicted probabilities of the top-10 classes

-6. Build and run the light-API demo
+4. Build and run the light-API demo
 ```shell
 cd ../mobile_light
 make
-adb -s emulator-5554 push mobilenetv1_light_api /data/local/tmp/
-adb -s emulator-5554 shell chmod +x /data/local/tmp/mobilenetv1_light_api
-adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_light_api --model_dir=/data/local/tmp/mobilenet_v1.opt"
+adb push mobilenetv1_light_api /data/local/tmp/
+adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/mobilenetv1_light_api
+adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+/data/local/tmp/mobilenetv1_light_api /data/local/tmp/mobilenet_v1.opt"
+```
+On success, the console prints the predicted probabilities of the top-10 classes
+
+
+5. Build and run the SSD object detection demo
+```shell
+cd ../ssd_detection
+wget https://paddle-inference-dist.bj.bcebos.com/mobilenetv1-ssd.tar.gz
+tar zxvf mobilenetv1-ssd.tar.gz
+make
+adb push ssd_detection /data/local/tmp/
+adb push test.jpg /data/local/tmp/
+adb push mobilenetv1-ssd /data/local/tmp
+adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/ssd_detection
+adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+/data/local/tmp/ssd_detection /data/local/tmp/mobilenetv1-ssd /data/local/tmp/test.jpg"
+adb pull /data/local/tmp/test_ssd_detection_result.jpg ./
+```
+On success, the detection result image test_ssd_detection_result.jpg is generated in the ssd_detection directory.
+
+6. Build and run the YOLOv3 object detection demo
+```shell
+cd ../yolov3_detection
+wget https://paddle-inference-dist.bj.bcebos.com/mobilenetv1-yolov3.tar.gz
+tar zxvf mobilenetv1-yolov3.tar.gz
+make
+adb push yolov3_detection /data/local/tmp/
+adb push test.jpg /data/local/tmp/
+adb push mobilenetv1-yolov3 /data/local/tmp
+adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/yolov3_detection
+adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+/data/local/tmp/yolov3_detection /data/local/tmp/mobilenetv1-yolov3 /data/local/tmp/test.jpg"
+adb pull /data/local/tmp/test_yolov3_detection_result.jpg ./
+```
+On success, the detection result image test_yolov3_detection_result.jpg is generated in the yolov3_detection directory.
+
+7. Build and run the image classification demo
+```shell
+cd ../mobile_classify
+wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz
+tar zxvf mobilenet_v1.tar.gz
+./model_optimize_tool optimize model
+make
+
+adb push mobile_classify /data/local/tmp/
+adb push test.jpg /data/local/tmp/
+adb push labels.txt /data/local/tmp/
+adb push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/mobile_classify
+adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+/data/local/tmp/mobile_classify /data/local/tmp/mobilenetv1opt2 /data/local/tmp/test.jpg /data/local/tmp/labels.txt"
+```
+On success, the console prints the predicted probabilities of the top 5 classes (the abbreviated `./model_optimize_tool optimize model` step above is expanded in the sketch after this section).
+- To see the predicted probabilities of the top 10 classes instead, append a topk value to the run command,
+  e.g.:
+  ```shell
+  adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+  /data/local/tmp/mobile_classify /data/local/tmp/mobilenetv1opt2/ /data/local/tmp/test.jpg /data/local/tmp/labels.txt 10"
+  ```
+- To see classification results for another model, append its model_dir and the model's input size to the run command,
+  e.g.:
+  ```shell
+  adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+  /data/local/tmp/mobile_classify /data/local/tmp/mobilenetv2opt2/ /data/local/tmp/test.jpg /data/local/tmp/labels.txt 10 224 224"
+  ```
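+
+A possible expansion of the abbreviated `./model_optimize_tool optimize model` step above (an illustrative sketch, assuming the tool was built with `./lite/tools/build.sh build_optimize_tool` as shown in the test_cv README at the end of this patch; the output name is chosen to match the mobilenetv1opt2 directory used in the run commands):
+```shell
+./model_optimize_tool \
+  --model_dir=mobilenet_v1 \
+  --optimize_out_type=naive_buffer \
+  --optimize_out=mobilenetv1opt2 \
+  --prefer_int8_kernel=false
+```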
+
+8. Build the model unit-test demo that uses the CV preprocessing library
+```shell
+cd ../test_cv
+wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz
+tar zxvf mobilenet_v1.tar.gz
+./model_optimize_tool optimize model
+make
+adb push test_model_cv /data/local/tmp/
+adb push test.jpg /data/local/tmp/
+adb push labels.txt /data/local/tmp/
+adb push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
+adb shell chmod +x /data/local/tmp/test_model_cv
+adb shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
+/data/local/tmp/test_model_cv /data/local/tmp/mobilenetv1opt2 /data/local/tmp/test.jpg /data/local/tmp/labels.txt"
+```
+On success, the console prints the predicted probabilities of the top 10 classes.
diff --git a/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv7
new file mode 100644
index 0000000000000000000000000000000000000000..dd6d4b0960160e140e2f051b78814d2fee08d5e0
--- /dev/null
+++ b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv7
@@ -0,0 +1,61 @@
+ARM_ABI = arm7
+export ARM_ABI
+
+include ../Makefile.def
+
+LITE_ROOT=../../../
+
+THIRD_PARTY_DIR=${LITE_ROOT}/third_party
+
+OPENCV_VERSION=opencv4.1.0
+
+OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \
+              ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a
+
+OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include
+
+CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
+
+CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
+
+###############################################################
+# How to use one of the static libraries:                     #
+#   `libpaddle_api_full_bundled.a`                            #
+#   `libpaddle_api_light_bundled.a`                           #
+###############################################################
+# Note: the shared library is used by default.                #
+###############################################################
+# 1. Comment above line using `libpaddle_light_api_shared.so`
+# 2.
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +mask_detection: fetch_opencv mask_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mask_detection.o -o mask_detection $(CXX_LIBS) $(LDFLAGS) + +mask_detection.o: mask_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mask_detection.o -c mask_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f mask_detection.o + rm -f mask_detection diff --git a/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv8 new file mode 100644 index 0000000000000000000000000000000000000000..c2f601ed2f68c342b47c5add451f84c537f978de --- /dev/null +++ b/lite/demo/cxx/makefiles/mask_detection/Makefile.android.armv8 @@ -0,0 +1,61 @@ +ARM_ABI = arm8 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +mask_detection: fetch_opencv mask_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mask_detection.o -o mask_detection $(CXX_LIBS) $(LDFLAGS) + +mask_detection.o: mask_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mask_detection.o -c mask_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f mask_detection.o + rm -f mask_detection diff --git a/lite/demo/cxx/makefiles/mobile_classify/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mobile_classify/Makefile.android.armv7 new file mode 100644 index 0000000000000000000000000000000000000000..8d446af9b174d8876fdd9aafd64bc2057dd7e17e --- /dev/null +++ b/lite/demo/cxx/makefiles/mobile_classify/Makefile.android.armv7 @@ -0,0 +1,61 @@ +ARM_ABI = arm7 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +mobile_classify: fetch_opencv mobile_classify.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobile_classify.o -o mobile_classify $(CXX_LIBS) $(LDFLAGS) + +mobile_classify.o: mobile_classify.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobile_classify.o -c mobile_classify.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f mobile_classify.o + rm -f mobile_classify diff --git a/lite/demo/cxx/makefiles/mobile_classify/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mobile_classify/Makefile.android.armv8 new file mode 100644 index 0000000000000000000000000000000000000000..255c42f2dca5364d9a639c993737608657568b17 --- /dev/null +++ b/lite/demo/cxx/makefiles/mobile_classify/Makefile.android.armv8 @@ -0,0 +1,61 @@ +ARM_ABI = arm8 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +mobile_classify: fetch_opencv mobile_classify.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobile_classify.o -o mobile_classify $(CXX_LIBS) $(LDFLAGS) + +mobile_classify.o: mobile_classify.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobile_classify.o -c mobile_classify.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f mobile_classify.o + rm -f mobile_classify diff --git a/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7 index f795b41d46acc3be67ff6c1a0bba0de1c1d8c82d..8ab8a3b7436c836f681510e28461628ed1038709 100644 --- a/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7 +++ b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7 @@ -5,9 +5,25 @@ include ../Makefile.def LITE_ROOT=../../../ -CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include +THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include -CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS) +THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a + +CXX_INCLUDES = $(INCLUDES) ${THIRD_PARTY_INCLUDES} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_full_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_full_bundled.a` + +#CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS) mobilenetv1_full_api: mobilenetv1_full_api.o $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_full_api.o -o mobilenetv1_full_api $(CXX_LIBS) $(LDFLAGS) diff --git a/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8 index d0767145b00bd40a3fbeff2aef4f7a0fc6f542d6..c13320603bcce91ebe1fca9014e36b07540abca1 100644 --- a/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8 +++ b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8 @@ -5,9 +5,25 @@ include ../Makefile.def LITE_ROOT=../../../ -CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include +THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include -CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS) +THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a + +CXX_INCLUDES = $(INCLUDES) ${THIRD_PARTY_INCLUDES} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_full_api_shared.so` +# 2. Undo comment below line using `libpaddle_api_full_bundled.a` + +#CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS) mobilenetv1_full_api: mobilenetv1_full_api.o $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_full_api.o -o mobilenetv1_full_api $(CXX_LIBS) $(LDFLAGS) diff --git a/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7 index d235d6e25fa9abe47ba50d8336cafcdd6580e30d..9150ae6e44e2314a482f7fcb3d139a20cf9f0304 100644 --- a/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7 +++ b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7 @@ -7,7 +7,19 @@ LITE_ROOT=../../../ CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include -CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) +CXX_LIBS = -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) mobilenetv1_light_api: mobilenetv1_light_api.o $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_light_api.o -o mobilenetv1_light_api $(CXX_LIBS) $(LDFLAGS) diff --git a/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8 index b91aadcef813de2a6f3371fe2cc4989bd87cf1ab..7a2dbdd0fcc9611fe79fb2660ad215ac4ba0d769 100644 --- a/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8 +++ b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8 @@ -7,7 +7,19 @@ LITE_ROOT=../../../ CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include -CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) +CXX_LIBS = -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) mobilenetv1_light_api: mobilenetv1_light_api.o $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_light_api.o -o mobilenetv1_light_api $(CXX_LIBS) $(LDFLAGS) diff --git a/lite/demo/cxx/makefiles/ssd_detection/Makefile.android.armv7 b/lite/demo/cxx/makefiles/ssd_detection/Makefile.android.armv7 new file mode 100644 index 0000000000000000000000000000000000000000..05f1c2e276b9cc41cfd4e3f9b4c82790d844ba52 --- /dev/null +++ b/lite/demo/cxx/makefiles/ssd_detection/Makefile.android.armv7 @@ -0,0 +1,61 @@ +ARM_ABI = arm7 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# 
`libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +ssd_detection: fetch_opencv ssd_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ssd_detection.o -o ssd_detection $(CXX_LIBS) $(LDFLAGS) + +ssd_detection.o: ssd_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ssd_detection.o -c ssd_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f ssd_detection.o + rm -f ssd_detection diff --git a/lite/demo/cxx/makefiles/ssd_detection/Makefile.android.armv8 b/lite/demo/cxx/makefiles/ssd_detection/Makefile.android.armv8 new file mode 100644 index 0000000000000000000000000000000000000000..77ff07df9541c554ac5fabf3cf56ee4a8904ea9c --- /dev/null +++ b/lite/demo/cxx/makefiles/ssd_detection/Makefile.android.armv8 @@ -0,0 +1,61 @@ +ARM_ABI = arm8 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +ssd_detection: fetch_opencv ssd_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ssd_detection.o -o ssd_detection $(CXX_LIBS) $(LDFLAGS) + +ssd_detection.o: ssd_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ssd_detection.o -c ssd_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f ssd_detection.o + rm -f ssd_detection diff --git a/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 new file mode 100644 index 0000000000000000000000000000000000000000..d659a316cd856fd550e83b125573409f239b8cf2 --- /dev/null +++ b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 @@ -0,0 +1,71 @@ +ARM_ABI = arm7 +LITE_WITH_CV = ON +export ARM_ABI +export LITE_WITH_CV + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +test_model_cv: fetch_opencv test_model_cv.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_model_cv.o -o test_model_cv $(CXX_LIBS) $(LDFLAGS) + +test_model_cv.o: test_model_cv.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_model_cv.o -c test_model_cv.cc + +test_img_prepross: fetch_opencv test_img_prepross.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_img_prepross.o -o test_img_prepross $(CXX_LIBS) $(LDFLAGS) + +test_img_prepross.o: test_img_prepross.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_img_prepross.o -c test_img_prepross.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f test_model_cv.o + rm -f test_model_cv + rm -f test_img_prepross.o + rm -f test_img_prepross diff --git a/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 new file mode 100644 index 0000000000000000000000000000000000000000..c80b07d5c029a3624a514e07375fd08e8770da25 --- /dev/null +++ b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 @@ -0,0 +1,70 @@ +ARM_ABI = arm8 +LITE_WITH_CV = ON +export ARM_ABI +export LITE_WITH_CV + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS) +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = ${OPENCV_LIBS} $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +test_model_cv: fetch_opencv test_model_cv.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_model_cv.o -o test_model_cv $(CXX_LIBS) $(LDFLAGS) + +test_model_cv.o: test_model_cv.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_model_cv.o -c test_model_cv.cc + +test_img_prepross: fetch_opencv test_img_prepross.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_img_prepross.o -o test_img_prepross $(CXX_LIBS) $(LDFLAGS) + +test_img_prepross.o: test_img_prepross.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_img_prepross.o -c test_img_prepross.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f test_model_cv.o + rm -f test_model_cv + rm -f test_img_prepross.o + rm -f test_img_prepross diff --git a/lite/demo/cxx/makefiles/yolov3_detection/Makefile.android.armv7 b/lite/demo/cxx/makefiles/yolov3_detection/Makefile.android.armv7 new file mode 100644 index 0000000000000000000000000000000000000000..b584f5623594fd64f10a86766828c62cdfe08aef --- /dev/null +++ b/lite/demo/cxx/makefiles/yolov3_detection/Makefile.android.armv7 @@ -0,0 +1,61 @@ +ARM_ABI = arm7 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +yolov3_detection: fetch_opencv yolov3_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) yolov3_detection.o -o yolov3_detection $(CXX_LIBS) $(LDFLAGS) + +yolov3_detection.o: yolov3_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o yolov3_detection.o -c yolov3_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f yolov3_detection.o + rm -f yolov3_detection diff --git a/lite/demo/cxx/makefiles/yolov3_detection/Makefile.android.armv8 b/lite/demo/cxx/makefiles/yolov3_detection/Makefile.android.armv8 new file mode 100644 index 0000000000000000000000000000000000000000..27779817012bce527d4506a0dcd377bf4ced3c1a --- /dev/null +++ b/lite/demo/cxx/makefiles/yolov3_detection/Makefile.android.armv8 @@ -0,0 +1,61 @@ +ARM_ABI = arm8 +export ARM_ABI + +include ../Makefile.def + +LITE_ROOT=../../../ + +THIRD_PARTY_DIR=${LITE_ROOT}/third_party + +OPENCV_VERSION=opencv4.1.0 + +OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \ + ../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a + +OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include + +CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include + +CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) + +############################################################### +# How to use one of static libaray: # +# `libpaddle_api_full_bundled.a` # +# `libpaddle_api_light_bundled.a` # +############################################################### +# Note: default use lite's shared library. # +############################################################### +# 1. Comment above line using `libpaddle_light_api_shared.so` +# 2. 
Undo comment below line using `libpaddle_api_light_bundled.a` + +#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) + +yolov3_detection: fetch_opencv yolov3_detection.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) yolov3_detection.o -o yolov3_detection $(CXX_LIBS) $(LDFLAGS) + +yolov3_detection.o: yolov3_detection.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o yolov3_detection.o -c yolov3_detection.cc + +fetch_opencv: + @ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR} + @ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \ + (echo "fetch opencv libs" && \ + wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz) + @ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \ + tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR} + + +.PHONY: clean +clean: + rm -f yolov3_detection.o + rm -f yolov3_detection diff --git a/lite/demo/cxx/mask_detection/mask_detection.cc b/lite/demo/cxx/mask_detection/mask_detection.cc new file mode 100644 index 0000000000000000000000000000000000000000..748b84365fc70aa59171a6bf8847f554308fdc8c --- /dev/null +++ b/lite/demo/cxx/mask_detection/mask_detection.cc @@ -0,0 +1,246 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <arm_neon.h>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include <vector>
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+#include "paddle_api.h"  // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+struct Object {
+  int batch_id;
+  cv::Rect rec;
+  int class_id;
+  float prob;
+};
+
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+// Fill the tensor with mean and scale, and transpose the layout from
+// NHWC to NCHW; NEON intrinsics process four pixels per iteration.
+void neon_mean_scale(const float* din,
+                     float* dout,
+                     int size,
+                     const std::vector<float> mean,
+                     const std::vector<float> scale) {
+  if (mean.size() != 3 || scale.size() != 3) {
+    std::cerr << "[ERROR] mean or scale size must equal to 3\n";
+    exit(1);
+  }
+  float32x4_t vmean0 = vdupq_n_f32(mean[0]);
+  float32x4_t vmean1 = vdupq_n_f32(mean[1]);
+  float32x4_t vmean2 = vdupq_n_f32(mean[2]);
+  float32x4_t vscale0 = vdupq_n_f32(scale[0]);
+  float32x4_t vscale1 = vdupq_n_f32(scale[1]);
+  float32x4_t vscale2 = vdupq_n_f32(scale[2]);
+
+  float* dout_c0 = dout;
+  float* dout_c1 = dout + size;
+  float* dout_c2 = dout + size * 2;
+
+  int i = 0;
+  for (; i < size - 3; i += 4) {
+    float32x4x3_t vin3 = vld3q_f32(din);
+    float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
+    float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
+    float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
+    float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
+    float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
+    float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
+    vst1q_f32(dout_c0, vs0);
+    vst1q_f32(dout_c1, vs1);
+    vst1q_f32(dout_c2, vs2);
+
+    din += 12;
+    dout_c0 += 4;
+    dout_c1 += 4;
+    dout_c2 += 4;
+  }
+  // Scalar tail for the remaining pixels
+  for (; i < size; i++) {
+    *(dout_c0++) = (*(din++) - mean[0]) * scale[0];
+    *(dout_c1++) = (*(din++) - mean[1]) * scale[1];
+    *(dout_c2++) = (*(din++) - mean[2]) * scale[2];
+  }
+}
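+
+// For each channel c, the routine above computes
+//   dout_c[i] = (din[3 * i + c] - mean[c]) * scale[c]
+// writing one contiguous plane per channel, so the interleaved BGR/RGB
+// (NHWC) input becomes the planar (NCHW) layout expected by the model's
+// input tensor.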
+
+void pre_process(const cv::Mat& img,
+                 int width,
+                 int height,
+                 const std::vector<float>& mean,
+                 const std::vector<float>& scale,
+                 float* data,
+                 bool is_scale = false) {
+  cv::Mat resized_img;
+  cv::resize(
+      img, resized_img, cv::Size(width, height), 0.f, 0.f, cv::INTER_CUBIC);
+  cv::Mat imgf;
+  float scale_factor = is_scale ? 1.f / 256 : 1.f;
+  resized_img.convertTo(imgf, CV_32FC3, scale_factor);
+  const float* dimg = reinterpret_cast<const float*>(imgf.data);
+  neon_mean_scale(dimg, data, width * height, mean, scale);
+}
+
+void RunModel(std::string det_model_dir,
+              std::string class_model_dir,
+              std::string img_path) {
+  // Prepare the input image; the detector runs on a shrunk copy
+  cv::Mat img = imread(img_path, cv::IMREAD_COLOR);
+  float shrink = 0.2;
+  int width = img.cols;
+  int height = img.rows;
+  int s_width = static_cast<int>(width * shrink);
+  int s_height = static_cast<int>(height * shrink);
+
+  // Detection
+  MobileConfig config;
+  config.set_model_dir(det_model_dir);
+
+  // Create the predictor for the detection model
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<MobileConfig>(config);
+
+  // Get the input tensor
+  std::unique_ptr<Tensor> input_tensor0(std::move(predictor->GetInput(0)));
+  input_tensor0->Resize({1, 3, s_height, s_width});
+  auto* data = input_tensor0->mutable_data<float>();
+
+  // Preprocess
+  std::vector<float> detect_mean = {104.f, 117.f, 123.f};
+  std::vector<float> detect_scale = {0.007843, 0.007843, 0.007843};
+  pre_process(img, s_width, s_height, detect_mean, detect_scale, data, false);
+
+  // Run the detection model
+  predictor->Run();
+
+  // Get the output tensor
+  std::unique_ptr<const Tensor> output_tensor0(
+      std::move(predictor->GetOutput(0)));
+  auto* outptr = output_tensor0->data<float>();
+  auto shape_out = output_tensor0->shape();
+  int64_t out_len = ShapeProduction(shape_out);
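+
+  // The detector output holds 6 floats per candidate box:
+  //   [class_id, confidence, xmin, ymin, xmax, ymax]
+  // with the box coordinates normalized to [0, 1], so they are rescaled
+  // by the original image size below.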
"wear mask" : "no mask"; + int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL; + double font_scale = 1.f; + int thickness = 1; + cv::Size text_size = + cv::getTextSize(text, font_face, font_scale, thickness, nullptr); + float new_font_scale = rec_clip.width * 0.7 * font_scale / text_size.width; + text_size = + cv::getTextSize(text, font_face, new_font_scale, thickness, nullptr); + cv::Point origin; + origin.x = rec_clip.x + 5; + origin.y = rec_clip.y + text_size.height + 5; + cv::putText(img, + text, + origin, + font_face, + new_font_scale, + cv::Scalar(0, 255, 255), + thickness, + cv::LINE_AA); + + std::cout << "detect face, location: x=" << rec_clip.x + << ", y=" << rec_clip.y << ", width=" << rec_clip.width + << ", height=" << rec_clip.height + << ", wear mask: " << (outptr[1] > classify_threshold) + << std::endl; + } + + // Write Result to Image File + int start = img_path.find_last_of("/"); + int end = img_path.find_last_of("."); + std::string img_name = img_path.substr(start + 1, end - start - 1); + std::string result_name = img_name + "_mask_detection_result.jpg"; + cv::imwrite(result_name, img); +} + +int main(int argc, char** argv) { + if (argc < 3) { + std::cerr << "[ERROR] usage: " << argv[0] + << " detction_model_dir classification_model_dir image_path\n"; + exit(1); + } + std::string detect_model_dir = argv[1]; + std::string classify_model_dir = argv[2]; + std::string img_path = argv[3]; + RunModel(detect_model_dir, classify_model_dir, img_path); + return 0; +} diff --git a/lite/demo/cxx/mobile_classify/mobile_classify.cc b/lite/demo/cxx/mobile_classify/mobile_classify.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0cf59e185e1330b7d8487d562afa0af29236007 --- /dev/null +++ b/lite/demo/cxx/mobile_classify/mobile_classify.cc @@ -0,0 +1,195 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <arm_neon.h>
+#include <algorithm>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include <utility>
+#include <vector>
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"
+#include "paddle_api.h"  // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+// Read the label file: for each line, keep the text after the first space
+// (dropping the trailing newline).
+void load_labels(std::string path, std::vector<std::string>* labels) {
+  FILE* fp = fopen(path.c_str(), "r");
+  if (fp == nullptr) {
+    printf("load label file failed \n");
+    return;
+  }
+  while (!feof(fp)) {
+    char str[1024];
+    fgets(str, 1024, fp);
+    std::string str_s(str);
+
+    if (str_s.length() > 0) {
+      for (int i = 0; i < str_s.length(); i++) {
+        if (str_s[i] == ' ') {
+          std::string strr = str_s.substr(i, str_s.length() - i - 1);
+          labels->push_back(strr);
+          i = str_s.length();
+        }
+      }
+    }
+  }
+  fclose(fp);
+}
+
+void print_topk(const float* scores,
+                const int size,
+                const int topk,
+                const std::vector<std::string>& labels) {
+  std::vector<std::pair<float, int>> vec;
+  vec.resize(size);
+  for (int i = 0; i < size; i++) {
+    vec[i] = std::make_pair(scores[i], i);
+  }
+
+  std::partial_sort(vec.begin(),
+                    vec.begin() + topk,
+                    vec.end(),
+                    std::greater<std::pair<float, int>>());
+
+  // Print the top-k indices and scores
+  for (int i = 0; i < topk; i++) {
+    float score = vec[i].first;
+    int index = vec[i].second;
+    printf("i: %d, index: %d, name: %s, score: %f \n",
+           i,
+           index,
+           labels[index].c_str(),
+           score);
+  }
+}
+
+// Fill the tensor with mean and scale, and transpose the layout from
+// NHWC to NCHW; NEON intrinsics process four pixels per iteration.
+void neon_mean_scale(
+    const float* din, float* dout, int size, float* mean, float* scale) {
+  float32x4_t vmean0 = vdupq_n_f32(mean[0]);
+  float32x4_t vmean1 = vdupq_n_f32(mean[1]);
+  float32x4_t vmean2 = vdupq_n_f32(mean[2]);
+  float32x4_t vscale0 = vdupq_n_f32(1.f / scale[0]);
+  float32x4_t vscale1 = vdupq_n_f32(1.f / scale[1]);
+  float32x4_t vscale2 = vdupq_n_f32(1.f / scale[2]);
+
+  float* dout_c0 = dout;
+  float* dout_c1 = dout + size;
+  float* dout_c2 = dout + size * 2;
+
+  int i = 0;
+  for (; i < size - 3; i += 4) {
+    float32x4x3_t vin3 = vld3q_f32(din);
+    float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
+    float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
+    float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
+    float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
+    float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
+    float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
+    vst1q_f32(dout_c0, vs0);
+    vst1q_f32(dout_c1, vs1);
+    vst1q_f32(dout_c2, vs2);
+
+    din += 12;
+    dout_c0 += 4;
+    dout_c1 += 4;
+    dout_c2 += 4;
+  }
+  // Scalar tail: write each channel to its own plane and divide by scale,
+  // matching the vectorized multiply-by-reciprocal above.
+  for (; i < size; i++) {
+    *(dout_c0++) = (*(din++) - mean[0]) / scale[0];
+    *(dout_c1++) = (*(din++) - mean[1]) / scale[1];
+    *(dout_c2++) = (*(din++) - mean[2]) / scale[2];
+  }
+}
+
+void pre_process(const cv::Mat& img,
+                 int width,
+                 int height,
+                 Tensor dstTensor,
+                 float* means,
+                 float* scales) {
+  cv::Mat rgb_img;
+  cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB);
+  cv::resize(rgb_img, rgb_img, cv::Size(width, height), 0.f, 0.f);
+  cv::Mat imgf;
+  rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f);
+  const float* dimg = reinterpret_cast<const float*>(imgf.data);
+  float* data = dstTensor.mutable_data<float>();
+  neon_mean_scale(dimg, data, width * height, means, scales);
+}
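+
+// With the ImageNet statistics passed in from RunModel below
+// (means {0.485, 0.456, 0.406}, scales {0.229, 0.224, 0.225}),
+// pre_process above yields out = (pixel / 255 - mean) / scale for every
+// channel, i.e. the standard ImageNet normalization.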
+
+void RunModel(std::string model_dir,
+              std::string img_path,
+              const std::vector<std::string>& labels,
+              const int topk,
+              int width,
+              int height) {
+  // 1. Set MobileConfig
+  MobileConfig config;
+  config.set_model_dir(model_dir);
+
+  // 2. Create PaddlePredictor by MobileConfig
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<MobileConfig>(config);
+
+  // 3. Prepare input data from image
+  std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+  input_tensor->Resize({1, 3, height, width});
+  auto* data = input_tensor->mutable_data<float>();  // allocates the buffer
+  // Read the image and preprocess it directly into the input tensor
+  cv::Mat img = imread(img_path, cv::IMREAD_COLOR);
+  float means[3] = {0.485f, 0.456f, 0.406f};
+  float scales[3] = {0.229f, 0.224f, 0.225f};
+  pre_process(img, width, height, *input_tensor, means, scales);
+
+  // 4. Run predictor
+  predictor->Run();
+
+  // 5. Get output and post process
+  std::unique_ptr<const Tensor> output_tensor(
+      std::move(predictor->GetOutput(0)));
+  auto* outptr = output_tensor->data<float>();
+  auto shape_out = output_tensor->shape();
+  int64_t cnt = 1;
+  for (auto& i : shape_out) {
+    cnt *= i;
+  }
+  print_topk(outptr, cnt, topk, labels);
+}
+
+int main(int argc, char** argv) {
+  if (argc < 4) {
+    std::cerr << "[ERROR] usage: " << argv[0]
+              << " model_dir image_path label_file\n";
+    exit(1);
+  }
+  printf("parameters model_dir, image_path and label_file are necessary \n");
+  printf("parameters topk, input_width and input_height are optional \n");
+  std::string model_dir = argv[1];
+  std::string img_path = argv[2];
+  std::string label_file = argv[3];
+  std::vector<std::string> labels;
+  load_labels(label_file, &labels);
+  int topk = 5;
+  int height = 224;
+  int width = 224;
+  if (argc > 4) {
+    topk = atoi(argv[4]);
+  }
+  if (argc > 6) {
+    width = atoi(argv[5]);
+    height = atoi(argv[6]);
+  }
+
+  RunModel(model_dir, img_path, labels, topk, width, height);
+  return 0;
+}
diff --git a/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc b/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc
index aa084d1fef7871ef11ac4864b30b3786691de387..0c9da1a76422edae45dfeec5d38556a5e2322a85 100644
--- a/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc
+++ b/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc
@@ -13,12 +13,10 @@
 // limitations under the License.
 
 #include <gflags/gflags.h>
-#include <stdio.h>
+#include <iostream>
 #include <vector>
-#include "paddle_api.h"          // NOLINT
-#include "paddle_use_kernels.h"  // NOLINT
-#include "paddle_use_ops.h"      // NOLINT
-#include "paddle_use_passes.h"   // NOLINT
+#include "paddle_api.h"         // NOLINT
+#include "paddle_use_passes.h"  // NOLINT
 
 using namespace paddle::lite_api;  // NOLINT
 
@@ -78,14 +76,22 @@ void RunModel() {
   // 6. Get output
   std::unique_ptr<const Tensor> output_tensor(
       std::move(predictor->GetOutput(0)));
-  printf("Output dim: %d\n", output_tensor->shape()[1]);
+  std::cout << "Output shape " << output_tensor->shape()[1] << std::endl;
   for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
-    printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+    std::cout << "Output[" << i << "]: " << output_tensor->data<float>()[i]
+              << std::endl;
   }
 }
 
 int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
+  if (FLAGS_model_dir == "" || FLAGS_optimized_model_dir == "") {
+    std::cerr << "[ERROR] usage: " << argv[0]
+              << " --model_dir=<your-model-dir>"
+              << " --optimized_model_dir=<your-optimized-model-dir> "
+              << " --prefer_int8_kernel=[true|false]\n";
+    exit(1);
+  }
   RunModel();
   return 0;
 }
diff --git a/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc b/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
index e1833814cad17b2af182443874c69f4c91e542fc..9d923cb87da5244e4550be3fb6936a650ec9b53a 100644
--- a/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
+++ b/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
@@ -12,27 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <gflags/gflags.h>
-#include <stdio.h>
+#include <iostream>
 #include <vector>
-#include "paddle_api.h"          // NOLINT
-#include "paddle_use_kernels.h"  // NOLINT
-#include "paddle_use_ops.h"      // NOLINT
+#include "paddle_api.h"  // NOLINT
 
 using namespace paddle::lite_api;  // NOLINT
 
-DEFINE_string(model_dir, "", "Model dir path.");
-
 int64_t ShapeProduction(const shape_t& shape) {
   int64_t res = 1;
   for (auto i : shape) res *= i;
   return res;
 }
 
-void RunModel() {
+void RunModel(std::string model_dir) {
   // 1. Set MobileConfig
   MobileConfig config;
-  config.set_model_dir(FLAGS_model_dir);
+  config.set_model_dir(model_dir);
+  // To load a model transformed by opt after release/v2.3.0, please use
+  // `set_model_from_file` as shown below.
+  // config.set_model_from_file(model_dir);
 
   // 2. Create PaddlePredictor by MobileConfig
   std::shared_ptr<PaddlePredictor> predictor =
@@ -52,14 +50,19 @@ void RunModel() {
   // 5. Get output
   std::unique_ptr<const Tensor> output_tensor(
       std::move(predictor->GetOutput(0)));
-  printf("Output dim: %d\n", output_tensor->shape()[1]);
+  std::cout << "Output shape " << output_tensor->shape()[1] << std::endl;
   for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
-    printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+    std::cout << "Output[" << i << "]: " << output_tensor->data<float>()[i]
+              << std::endl;
   }
 }
 
 int main(int argc, char** argv) {
-  google::ParseCommandLineFlags(&argc, &argv, true);
-  RunModel();
+  if (argc < 2) {
+    std::cerr << "[ERROR] usage: " << argv[0] << " naive_buffer_model_dir\n";
+    exit(1);
+  }
+  std::string model_dir = argv[1];
+  RunModel(model_dir);
   return 0;
 }
diff --git a/lite/demo/cxx/ssd_detection/ssd_detection.cc b/lite/demo/cxx/ssd_detection/ssd_detection.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2408afcbf64a24924eca119a9d9481dc030250c9
--- /dev/null
+++ b/lite/demo/cxx/ssd_detection/ssd_detection.cc
@@ -0,0 +1,209 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" // NOLINT + +using namespace paddle::lite_api; // NOLINT + +struct Object { + int batch_id; + cv::Rect rec; + int class_id; + float prob; +}; + +int64_t ShapeProduction(const shape_t& shape) { + int64_t res = 1; + for (auto i : shape) res *= i; + return res; +} + +const char* class_names[] = { + "background", "aeroplane", "bicycle", "bird", "boat", + "bottle", "bus", "car", "cat", "chair", + "cow", "diningtable", "dog", "horse", "motorbike", + "person", "pottedplant", "sheep", "sofa", "train", + "tvmonitor"}; + +// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up +void neon_mean_scale(const float* din, + float* dout, + int size, + const std::vector mean, + const std::vector scale) { + if (mean.size() != 3 || scale.size() != 3) { + std::cerr << "[ERROR] mean or scale size must equal to 3\n"; + exit(1); + } + float32x4_t vmean0 = vdupq_n_f32(mean[0]); + float32x4_t vmean1 = vdupq_n_f32(mean[1]); + float32x4_t vmean2 = vdupq_n_f32(mean[2]); + float32x4_t vscale0 = vdupq_n_f32(1.f / scale[0]); + float32x4_t vscale1 = vdupq_n_f32(1.f / scale[1]); + float32x4_t vscale2 = vdupq_n_f32(1.f / scale[2]); + + float* dout_c0 = dout; + float* dout_c1 = dout + size; + float* dout_c2 = dout + size * 2; + + int i = 0; + for (; i < size - 3; i += 4) { + float32x4x3_t vin3 = vld3q_f32(din); + float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0); + float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1); + float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2); + float32x4_t vs0 = vmulq_f32(vsub0, vscale0); + float32x4_t vs1 = vmulq_f32(vsub1, vscale1); + float32x4_t vs2 = vmulq_f32(vsub2, vscale2); + vst1q_f32(dout_c0, vs0); + vst1q_f32(dout_c1, vs1); + vst1q_f32(dout_c2, vs2); + + din += 12; + dout_c0 += 4; + dout_c1 += 4; + dout_c2 += 4; + } + for (; i < size; i++) { + *(dout_c0++) = (*(din++) - mean[0]) * scale[0]; + *(dout_c1++) = (*(din++) - mean[1]) * scale[1]; + *(dout_c2++) = (*(din++) - mean[2]) * scale[2]; + } +} + +void pre_process(const cv::Mat& img, int width, int height, float* data) { + cv::Mat rgb_img; + cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB); + cv::resize(rgb_img, rgb_img, cv::Size(width, height), 0.f, 0.f); + cv::Mat imgf; + rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f); + std::vector mean = {0.5f, 0.5f, 0.5f}; + std::vector scale = {0.5f, 0.5f, 0.5f}; + const float* dimg = reinterpret_cast(imgf.data); + neon_mean_scale(dimg, data, width * height, mean, scale); +} + +std::vector detect_object(const float* data, + int count, + float thresh, + cv::Mat& image) { // NOLINT + if (data == nullptr) { + std::cerr << "[ERROR] data can not be nullptr\n"; + exit(1); + } + std::vector rect_out; + for (int iw = 0; iw < count; iw++) { + int oriw = image.cols; + int orih = image.rows; + if (data[1] > thresh && static_cast(data[0]) > 0) { + Object obj; + int x = static_cast(data[2] * oriw); + int y = static_cast(data[3] * orih); + int w = static_cast(data[4] * oriw) - x; + int h = static_cast(data[5] * orih) - y; + cv::Rect rec_clip = + cv::Rect(x, y, w, h) & cv::Rect(0, 0, image.cols, image.rows); + obj.batch_id = 0; + obj.class_id = static_cast(data[0]); + obj.prob = data[1]; + obj.rec = rec_clip; + if (w > 0 && h > 0 && obj.prob <= 1) { + rect_out.push_back(obj); + cv::rectangle(image, rec_clip, cv::Scalar(0, 0, 255), 2, cv::LINE_AA); + std::string str_prob = std::to_string(obj.prob); + std::string text = 
std::string(class_names[obj.class_id]) + ": " + str_prob.substr(0, str_prob.find(".") + 4); + int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL; + double font_scale = 1.f; + int thickness = 2; + cv::Size text_size = + cv::getTextSize(text, font_face, font_scale, thickness, nullptr); + float new_font_scale = w * 0.35 * font_scale / text_size.width; + text_size = cv::getTextSize( + text, font_face, new_font_scale, thickness, nullptr); + cv::Point origin; + origin.x = x + 10; + origin.y = y + text_size.height + 10; + cv::putText(image, + text, + origin, + font_face, + new_font_scale, + cv::Scalar(0, 255, 255), + thickness, + cv::LINE_AA); + + std::cout << "detection, image size: " << image.cols << ", " + << image.rows + << ", detect object: " << class_names[obj.class_id] + << ", score: " << obj.prob << ", location: x=" << x + << ", y=" << y << ", width=" << w << ", height=" << h + << std::endl; + } + } + data += 6; + } + return rect_out; +} + +void RunModel(std::string model_dir, std::string img_path) { + // 1. Set MobileConfig + MobileConfig config; + config.set_model_dir(model_dir); + + // 2. Create PaddlePredictor by MobileConfig + std::shared_ptr predictor = + CreatePaddlePredictor(config); + + // 3. Prepare input data from image + std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); + const int in_width = 300; + const int in_height = 300; + input_tensor->Resize({1, 3, in_height, in_width}); + auto* data = input_tensor->mutable_data(); + cv::Mat img = imread(img_path, cv::IMREAD_COLOR); + pre_process(img, in_width, in_height, data); + + // 4. Run predictor + predictor->Run(); + + // 5. Get output and post process + std::unique_ptr output_tensor( + std::move(predictor->GetOutput(0))); + auto* outptr = output_tensor->data(); + auto shape_out = output_tensor->shape(); + int64_t cnt = ShapeProduction(shape_out); + auto rec_out = detect_object(outptr, static_cast(cnt / 6), 0.6f, img); + int start = img_path.find_last_of("/"); + int end = img_path.find_last_of("."); + std::string img_name = img_path.substr(start + 1, end - start - 1); + std::string result_name = img_name + "_ssd_detection_result.jpg"; + cv::imwrite(result_name, img); +} + +int main(int argc, char** argv) { + if (argc < 3) { + std::cerr << "[ERROR] usage: " << argv[0] << " model_dir image_path\n"; + exit(1); + } + std::string model_dir = argv[1]; + std::string img_path = argv[2]; + RunModel(model_dir, img_path); + return 0; +} diff --git a/lite/demo/cxx/test_cv/README.md b/lite/demo/cxx/test_cv/README.md new file mode 100644 index 0000000000000000000000000000000000000000..36d2985a4fd4f243027f8caab9b6c5a8beb94cad --- /dev/null +++ b/lite/demo/cxx/test_cv/README.md @@ -0,0 +1,131 @@ +# Using the image preprocess library +1. Download the source code (https://github.com/PaddlePaddle/Paddle-Lite), turn on LITE_WITH_CV=ON, and build in full_publish mode +example: +```shell +set BUILD_WITH_CV=ON or LITE_WITH_CV=ON +./lite/tools/build.sh +--arm_os=android +--arm_abi=armv8 +--arm_lang=gcc +--android_stl=c++_static +full_publish +``` + +2. Prepare the model and optimize it +example: +```shell +wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz +tar zxvf mobilenet_v1.tar.gz +./lite/tools/build.sh build_optimize_tool +./build.model_optimize_tool/lite/api/model_optimize_tool +--optimize_out_type=naive_buffer +--optimize_out=model_dir +--model_dir=model_dir +--prefer_int8_kernel=false +``` + +3. Build and run the complete test_model_cv demo +example: +```shell +cd inference_lite_lib.android.armv8/demo/cxx/test_cv +``` + +- Edit the Makefile and comment out the rules that build test_img_propress + ```shell + test_model_cv: fetch_opencv test_model_cv.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_model_cv.o -o test_model_cv $(CXX_LIBS) $(LDFLAGS) + + test_model_cv.o: test_model_cv.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_model_cv.o -c test_model_cv.cc + + #test_img_propress: fetch_opencv test_img_propress.o + # $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_img_propress.o -o test_img_propress $(CXX_LIBS) $(LDFLAGS) + + #test_img_propress.o: test_img_propress.cc + # $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_img_propress.o -c test_img_propress.cc + + .PHONY: clean + clean: + rm -f test_model_cv.o + rm -f test_model_cv + #rm -f test_img_propress.o + #rm -f test_img_propress + ``` +- Edit ../../..//cxx/include/paddle_image_preprocess.h and change the include path of the paddle_api.h headers + ```shell + origin: + #include "lite/api/paddle_api.h" + #include "lite/api/paddle_place.h" + now: + #include "paddle_api.h" + #include "paddle_place.h" + ``` +- The model under test must be an optimized model + +```shell +make + +adb -s device_id push mobilenet_v1 /data/local/tmp/ +adb -s device_id push test_model_cv /data/local/tmp/ +adb -s device_id push test.jpg /data/local/tmp/ +adb -s device_id push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/ +adb -s device_id shell chmod +x /data/local/tmp/test_model_cv +adb -s device_id shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && +/data/local/tmp/test_model_cv /data/local/tmp/mobilenet_v1 /data/local/tmp/test.jpg 1 3 224 224 " +``` +On success, part of the prediction results will be printed to the console + +4. Build and run the complete test_img_preprocess demo +example: +```shell +cd inference_lite_lib.android.armv8/demo/cxx/test_cv +``` + +- Edit the Makefile and comment out the rules that build test_model_cv + ```shell + #test_model_cv: fetch_opencv test_model_cv.o + # $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_model_cv.o -o test_model_cv $(CXX_LIBS) $(LDFLAGS) + + #test_model_cv.o: test_model_cv.cc + # $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_model_cv.o -c test_model_cv.cc + + test_img_propress: fetch_opencv test_img_propress.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_img_propress.o -o test_img_propress $(CXX_LIBS) $(LDFLAGS) + + test_img_propress.o: test_img_propress.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_img_propress.o -c test_img_propress.cc + + .PHONY: clean + clean: + #rm -f test_model_cv.o + #rm -f test_model_cv + rm -f test_img_propress.o + rm -f test_img_propress + ``` +- Edit ../../..//cxx/include/paddle_image_preprocess.h and change the include path of the paddle_api.h headers + ```shell + origin: + #include "lite/api/paddle_api.h" + #include "lite/api/paddle_place.h" + now: + #include "paddle_api.h" + #include "paddle_place.h" + ``` +- The model under test must be an optimized model + +```shell +make + +adb -s device_id push mobilenet_v1 /data/local/tmp/ +adb -s device_id push test_img_propress /data/local/tmp/ +adb -s device_id push test.jpg /data/local/tmp/ +adb -s device_id push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/ +adb -s device_id shell chmod +x /data/local/tmp/test_img_propress +adb -s device_id shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH && +/data/local/tmp/test_img_propress /data/local/tmp/test.jpg /data/local/tmp/ 3 3 1 3 224 224 /data/local/tmp/mobilenet_v1 " +adb -s device_id pull /data/local/tmp/resize.jpg ./ +adb -s device_id pull /data/local/tmp/convert.jpg ./ +adb -s device_id pull /data/local/tmp/flip.jpg ./ +adb -s device_id pull /data/local/tmp/rotate.jpg ./ +``` +On success, the console will print the elapsed time of OpenCV and Paddle-Lite; the generated image preprocess result images (e.g. resize.jpg, convert.jpg) will also appear in the test_cv directory diff --git a/lite/demo/cxx/test_cv/test_img_prepross.cc b/lite/demo/cxx/test_cv/test_img_prepross.cc new file mode 100644 index 0000000000000000000000000000000000000000..c2cbd66cc0a15a1032141641d83fbf8db85d20bf --- /dev/null +++ b/lite/demo/cxx/test_cv/test_img_prepross.cc @@ -0,0 +1,389 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" // NOLINT +#include "paddle_image_preprocess.h" // NOLINT +#include "time.h" // NOLINT +typedef paddle::lite_api::Tensor Tensor; +typedef paddle::lite::utils::cv::ImageFormat ImageFormat; +typedef paddle::lite::utils::cv::FlipParam FlipParam; +typedef paddle::lite::utils::cv::TransParam TransParam; +typedef paddle::lite::utils::cv::ImagePreprocess ImagePreprocess; +typedef paddle::lite_api::DataLayoutType LayoutType; +using namespace paddle::lite_api; // NOLINT + +void fill_with_mat(cv::Mat& mat, uint8_t* src) { // NOLINT + for (int i = 0; i < mat.rows; i++) { + for (int j = 0; j < mat.cols; j++) { + int tmp = (i * mat.cols + j) * 3; + cv::Vec3b& rgb = mat.at(i, j); + rgb[0] = src[tmp]; + rgb[1] = src[tmp + 1]; + rgb[2] = src[tmp + 2]; + } + } +} +void test_img(std::vector cluster_id, + std::vector thread_num, + std::string img_path, + std::string dst_path, + ImageFormat srcFormat, + ImageFormat dstFormat, + int width, + int height, + float rotate, + FlipParam flip, + LayoutType layout, + std::string model_dir, + int test_iter = 1) { + // init + // paddle::lite::DeviceInfo::Init(); + // read img and pre-process + cv::Mat img = imread(img_path, cv::IMREAD_COLOR); + float means[3] = {0.485f, 0.456f, 0.406f}; + float scales[3] = {0.229f, 0.224f, 0.225f}; + int srch = img.rows; + int srcw = img.cols; + for (auto& cls : cluster_id) { + for (auto& th : thread_num) { + std::cout << "cluster: " << cls << ", threads: " << th << std::endl; + // 1. Set MobileConfig + MobileConfig config; + config.set_model_dir(model_dir); + config.set_power_mode((PowerMode)cls); + config.set_threads(th); + std::cout << "model: " << model_dir; + + // 2. Create PaddlePredictor by MobileConfig + std::shared_ptr predictor = + CreatePaddlePredictor(config); + + // 3.
Prepare input data from image + std::unique_ptr input_tensor(predictor->GetInput(0)); + + /* + imread(img_path, param) + IMREAD_UNCHANGED(<0) loads the image as-is, without any conversion + IMREAD_GRAYSCALE ( 0) loads the image as a grayscale image + IMREAD_COLOR (>0) loads the image as a 3-channel BGR image + */ + cv::Mat img; + if (srcFormat == ImageFormat::BGR || srcFormat == ImageFormat::RGB) { + img = imread(img_path, cv::IMREAD_COLOR); + } else if (srcFormat == ImageFormat::GRAY) { + img = imread(img_path, cv::IMREAD_GRAYSCALE); + } else { + printf("format %d is not supported \n", srcFormat); + return; + } + if (img.empty()) { + std::cout << "opencv read image " << img_path.c_str() << " failed" + << std::endl; + return; + } + int srch = img.rows; + int srcw = img.cols; + int dsth = height; + int dstw = width; + + std::cout << " input tensor size, num= " << 1 << ", channel= " << 1 + << ", height= " << srch << ", width= " << srcw + << ", srcFormat= " << (ImageFormat)srcFormat << std::endl; + // RGBA = 0, BGRA, RGB, BGR, GRAY, NV21 = 11, NV12, + if (srcFormat == ImageFormat::GRAY) { + std::cout << "srcFormat: GRAY" << std::endl; + } + if (srcFormat == ImageFormat::BGR) { + std::cout << "srcFormat: BGR" << std::endl; + } + if (srcFormat == ImageFormat::RGB) { + std::cout << "srcFormat: RGB" << std::endl; + } + std::cout << " output tensor size, num=" << 1 << ", channel=" << 1 + << ", height=" << dsth << ", width=" << dstw + << ", dstFormat= " << (ImageFormat)dstFormat << std::endl; + + if (dstFormat == ImageFormat::GRAY) { + std::cout << "dstFormat: GRAY" << std::endl; + } + if (dstFormat == ImageFormat::BGR) { + std::cout << "dstFormat: BGR" << std::endl; + } + if (dstFormat == ImageFormat::RGB) { + std::cout << "dstFormat: RGB" << std::endl; + } + + std::cout << "Rotate = " << rotate << ", Flip = " << flip + << ", Layout = " << static_cast(layout) << std::endl; + if (static_cast(layout) != 1 && static_cast(layout) != 3) { + std::cout << "layout " << static_cast(layout) + << " is not supported" << std::endl; + } + int size = 3 * srch * srcw; + if (srcFormat == ImageFormat::BGR || srcFormat == ImageFormat::RGB) { + size = 3 * srch * srcw; + } else if (srcFormat == ImageFormat::GRAY) { + size = srch * srcw; + } + uint8_t* src = img.data; + + int out_size = srch * srcw; + int resize = dstw * dsth; + if (dstFormat == ImageFormat::BGR || dstFormat == ImageFormat::RGB) { + out_size = 3 * srch * srcw; + resize = 3 * dsth * dstw; + } else if (dstFormat == ImageFormat::GRAY) { + out_size = srch * srcw; + resize = dsth * dstw; + } + // out + uint8_t* lite_dst = new uint8_t[out_size]; + uint8_t* resize_tmp = new uint8_t[resize]; + uint8_t* tv_out_ratote = new uint8_t[out_size]; + uint8_t* tv_out_flip = new uint8_t[out_size]; + std::vector shape_out = {1, 3, srch, srcw}; + + input_tensor->Resize(shape_out); + Tensor dst_tensor = *input_tensor; + std::cout << "opencv compute" << std::endl; + cv::Mat im_convert; + cv::Mat im_resize; + cv::Mat im_rotate; + cv::Mat im_flip; + double to_1 = 0; + double to_2 = 0; + double to_3 = 0; + double to_4 = 0; + double to1 = 0; + for (int i = 0; i < test_iter; i++) { + clock_t start = clock(); + clock_t begin = clock(); + // convert bgr-gray + if (dstFormat == srcFormat) { + im_convert = img; + } else if (dstFormat == ImageFormat::BGR && + srcFormat == ImageFormat::GRAY) { + cv::cvtColor(img, im_convert, cv::COLOR_GRAY2BGR); + } else if (srcFormat == ImageFormat::BGR && + dstFormat == ImageFormat::GRAY) { + cv::cvtColor(img, im_convert, cv::COLOR_BGR2GRAY); + } else { + printf("convert format error \n"); +
return; + } + clock_t end = clock(); + to_1 += (end - begin); + + begin = clock(); + // resize default linear + cv::resize(im_convert, im_resize, cv::Size(dstw, dsth), 0.f, 0.f); + end = clock(); + to_2 += (end - begin); + + begin = clock(); + // rotate 90 + if (rotate == 90) { + cv::flip(im_convert.t(), im_rotate, 1); + } else if (rotate == 180) { + cv::flip(im_convert, im_rotate, -1); + } else if (rotate == 270) { + cv::flip(im_convert.t(), im_rotate, 0); + } + end = clock(); + to_3 += (end - begin); + + begin = clock(); + // flip + cv::flip(im_convert, im_flip, flip); + end = clock(); + to_4 += (end - begin); + clock_t ovet = clock(); + to1 += (ovet - start); + } + + std::cout << "Paddle-lite compute" << std::endl; + double lite_to = 0; + double lite_to_1 = 0; + double lite_to_2 = 0; + double lite_to_3 = 0; + double lite_to_4 = 0; + double lite_to_5 = 0; + TransParam tparam; + tparam.ih = srch; + tparam.iw = srcw; + tparam.oh = dsth; + tparam.ow = dstw; + tparam.flip_param = flip; + tparam.rotate_param = rotate; + + ImagePreprocess image_preprocess(srcFormat, dstFormat, tparam); + + for (int i = 0; i < test_iter; ++i) { + clock_t start = clock(); + clock_t begin = clock(); + image_preprocess.imageConvert(src, lite_dst); + clock_t end = clock(); + lite_to_1 += (end - begin); + + begin = clock(); + image_preprocess.imageResize(lite_dst, resize_tmp); + end = clock(); + lite_to_2 += (end - begin); + + begin = clock(); + image_preprocess.imageRotate( + lite_dst, tv_out_ratote, (ImageFormat)dstFormat, srcw, srch, 90); + end = clock(); + lite_to_3 += (end - begin); + + begin = clock(); + image_preprocess.imageFlip( + lite_dst, tv_out_flip, (ImageFormat)dstFormat, srcw, srch, flip); + end = clock(); + lite_to_4 += (end - begin); + + clock_t over = clock(); + lite_to += (over - start); + + begin = clock(); + image_preprocess.image2Tensor(lite_dst, + &dst_tensor, + (ImageFormat)dstFormat, + srcw, + srch, + layout, + means, + scales); + end = clock(); + lite_to_5 += (end - begin); + } + to_1 = 1000 * to_1 / CLOCKS_PER_SEC; + to_2 = 1000 * to_2 / CLOCKS_PER_SEC; + to_3 = 1000 * to_3 / CLOCKS_PER_SEC; + to_4 = 1000 * to_4 / CLOCKS_PER_SEC; + to1 = 1000 * to1 / CLOCKS_PER_SEC; + std::cout << "opencv convert run time: " << to_1 + << "ms, avg: " << to_1 / test_iter << std::endl; + std::cout << "opencv resize run time: " << to_2 + << "ms, avg: " << to_2 / test_iter << std::endl; + std::cout << "opencv rotate run time: " << to_3 + << "ms, avg: " << to_3 / test_iter << std::endl; + std::cout << "opencv flip time: " << to_4 + << "ms, avg: " << to_4 / test_iter << std::endl; + std::cout << "opencv total run time: " << to1 + << "ms, avg: " << to1 / test_iter << std::endl; + std::cout << "------" << std::endl; + + lite_to_1 = 1000 * lite_to_1 / CLOCKS_PER_SEC; + lite_to_2 = 1000 * lite_to_2 / CLOCKS_PER_SEC; + lite_to_3 = 1000 * lite_to_3 / CLOCKS_PER_SEC; + lite_to_4 = 1000 * lite_to_4 / CLOCKS_PER_SEC; + lite_to_5 = 1000 * lite_to_5 / CLOCKS_PER_SEC; + lite_to = 1000 * lite_to / CLOCKS_PER_SEC; + std::cout << "lite convert run time: " << lite_to_1 + << "ms, avg: " << lite_to_1 / test_iter << std::endl; + std::cout << "lite resize run time: " << lite_to_2 + << "ms, avg: " << lite_to_2 / test_iter << std::endl; + std::cout << "lite rotate run time: " << lite_to_3 + << "ms, avg: " << lite_to_3 / test_iter << std::endl; + std::cout << "lite flip time: " << lite_to_4 + << "ms, avg: " << lite_to_4 / test_iter << std::endl; + std::cout << "lite total run time: " << lite_to + << "ms, avg: " << lite_to / test_iter 
<< std::endl; + std::cout << "lite img2tensor time: " << lite_to_5 + << "ms, avg: " << lite_to_5 / test_iter << std::endl; + std::cout << "------" << std::endl; + + double max_ratio = 0; + double max_diff = 0; + const double eps = 1e-6f; + // save_img + std::cout << "write image: " << std::endl; + std::string resize_name = dst_path + "/resize.jpg"; + std::string convert_name = dst_path + "/convert.jpg"; + std::string rotate_name = dst_path + "/rotate.jpg"; + std::string flip_name = dst_path + "/flip.jpg"; + cv::Mat resize_mat(dsth, dstw, CV_8UC3); + cv::Mat convert_mat(srch, srcw, CV_8UC3); + cv::Mat rotate_mat; + if (rotate == 90 || rotate == 270) { + rotate_mat = cv::Mat(srcw, srch, CV_8UC3); + } else { + rotate_mat = cv::Mat(srch, srcw, CV_8UC3); + } + cv::Mat flip_mat(srch, srcw, CV_8UC3); + fill_with_mat(resize_mat, resize_tmp); + fill_with_mat(convert_mat, lite_dst); + fill_with_mat(rotate_mat, tv_out_ratote); + fill_with_mat(flip_mat, tv_out_flip); + cv::imwrite(convert_name, convert_mat); + cv::imwrite(resize_name, resize_mat); + cv::imwrite(rotate_name, rotate_mat); + cv::imwrite(flip_name, flip_mat); + delete[] lite_dst; + delete[] resize_tmp; + delete[] tv_out_ratote; + delete[] tv_out_flip; + } + } +} + +int main(int argc, char** argv) { + if (argc < 7) { + std::cerr << "[ERROR] usage: " << argv[0] + << " image_path dst_path srcFormat dstFormat width height" + << " [model_dir] [flip] [rotate] [layout]\n"; + exit(1); + } + std::string image_path = argv[1]; + std::string dst_path = argv[2]; + int srcFormat = atoi(argv[3]); + int dstFormat = atoi(argv[4]); + int width = atoi(argv[5]); + int height = atoi(argv[6]); + int flip = -1; + float rotate = 90; + int layout = 1; + std::string model_dir = "mobilenet_v1"; + if (argc > 7) { + model_dir = argv[7]; + } + if (argc > 8) { + flip = atoi(argv[8]); + } + if (argc > 9) { + rotate = atoi(argv[9]); + } + if (argc > 10) { + layout = atoi(argv[10]); + } + test_img({3}, + {1, 2, 4}, + image_path, + dst_path, + (ImageFormat)srcFormat, + (ImageFormat)dstFormat, + width, + height, + rotate, + (FlipParam)flip, + (LayoutType)layout, + model_dir, + 20); + return 0; +} diff --git a/lite/demo/cxx/test_cv/test_model_cv.cc b/lite/demo/cxx/test_cv/test_model_cv.cc new file mode 100644 index 0000000000000000000000000000000000000000..24f408bf4a55ea2d499e39902201597c0e8c6e4e --- /dev/null +++ b/lite/demo/cxx/test_cv/test_model_cv.cc @@ -0,0 +1,224 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
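A note on the timing used in these benchmark demos: `clock()` measures CPU time accumulated over all threads of the process, so for multi-threaded runs it can report more than the elapsed wall time. A small wall-clock helper based on `std::chrono` (a sketch, not part of the demos) would look like:

```cpp
#include <chrono>

// Minimal wall-clock timer; fn is any callable, e.g.
//   double t = elapsed_ms([&] { predictor->Run(); });
template <typename F>
double elapsed_ms(F&& fn) {
  auto t0 = std::chrono::steady_clock::now();
  fn();
  auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(t1 - t0).count();
}
```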
+ +#include +#include +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" // NOLINT +#include "paddle_image_preprocess.h" // NOLINT +#include "time.h" // NOLINT + +using namespace paddle::lite_api; // NOLINT + +int64_t ShapeProduction(const shape_t& shape) { + int64_t res = 1; + for (auto i : shape) res *= i; + return res; +} +// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up +void neon_mean_scale( + const float* din, float* dout, int size, float* mean, float* scale) { + float32x4_t vmean0 = vdupq_n_f32(mean[0]); + float32x4_t vmean1 = vdupq_n_f32(mean[1]); + float32x4_t vmean2 = vdupq_n_f32(mean[2]); + float32x4_t vscale0 = vdupq_n_f32(1.f / scale[0]); + float32x4_t vscale1 = vdupq_n_f32(1.f / scale[1]); + float32x4_t vscale2 = vdupq_n_f32(1.f / scale[2]); + + float* dout_c0 = dout; + float* dout_c1 = dout + size; + float* dout_c2 = dout + size * 2; + + int i = 0; + for (; i < size - 3; i += 4) { + float32x4x3_t vin3 = vld3q_f32(din); + float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0); + float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1); + float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2); + float32x4_t vs0 = vmulq_f32(vsub0, vscale0); + float32x4_t vs1 = vmulq_f32(vsub1, vscale1); + float32x4_t vs2 = vmulq_f32(vsub2, vscale2); + vst1q_f32(dout_c0, vs0); + vst1q_f32(dout_c1, vs1); + vst1q_f32(dout_c2, vs2); + + din += 12; + dout_c0 += 4; + dout_c1 += 4; + dout_c2 += 4; + } + for (; i < size; i++) { + *(dout_c0++) = (*(din++) - mean[0]) * scale[0]; + *(dout_c1++) = (*(din++) - mean[1]) * scale[1]; + *(dout_c2++) = (*(din++) - mean[2]) * scale[2]; + } +} +void pre_process(const cv::Mat& img, int width, int height, Tensor dstTensor) { +#ifdef LITE_WITH_CV + typedef paddle::lite::utils::cv::ImageFormat ImageFormat; + typedef paddle::lite::utils::cv::FlipParam FlipParam; + typedef paddle::lite::utils::cv::TransParam TransParam; + typedef paddle::lite::utils::cv::ImagePreprocess ImagePreprocess; + typedef paddle::lite_api::DataLayoutType LayoutType; + // init TransParam + TransParam tp; + tp.iw = img.cols; + tp.ih = img.rows; + tp.ow = width; + tp.oh = height; + ImageFormat srcFormat = ImageFormat::BGR; + ImageFormat dstFormat = ImageFormat::RGB; + // init ImagePreprocess + ImagePreprocess img_process(srcFormat, dstFormat, tp); + // init temp var + const uint8_t* img_ptr = reinterpret_cast(img.data); + uint8_t* rgb_ptr = new uint8_t[img.cols * img.rows * 3]; + uint8_t* resize_ptr = new uint8_t[width * height * 3]; + // do convert bgr--rgb + img_process.imageConvert(img_ptr, rgb_ptr); + // do resize + img_process.imageResize(rgb_ptr, resize_ptr); + // data--tensor and normalize + float means[3] = {103.94f, 116.78f, 123.68f}; + float scales[3] = {0.017f, 0.017f, 0.017f}; + img_process.image2Tensor( + resize_ptr, &dstTensor, LayoutType::kNCHW, means, scales); + float* data = dstTensor.mutable_data(); +#else + cv::Mat rgb_img; + cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB); + cv::resize(rgb_img, rgb_img, cv::Size(width, height), 0.f, 0.f); + cv::Mat imgf; + rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f); + float means[3] = {0.485f, 0.456f, 0.406f}; + float scales[3] = {0.229f, 0.224f, 0.225f}; + const float* dimg = reinterpret_cast(imgf.data); + float* data = dstTensor.mutable_data(); + neon_mean_scale(dimg, data, width * height, means, scales); +#endif +} + +void RunModel(std::string model_dir, + std::string img_path, + std::vector input_shape, + PowerMode power_mode, + int thread_num, + int
test_iter, + int warmup = 0) { + // 1. Set MobileConfig + MobileConfig config; + config.set_model_dir(model_dir); + config.set_power_mode(power_mode); + config.set_threads(thread_num); + + // 2. Create PaddlePredictor by MobileConfig + std::shared_ptr predictor = + CreatePaddlePredictor(config); + // 3. Prepare input data from image + std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); + input_tensor->Resize( + {input_shape[0], input_shape[1], input_shape[2], input_shape[3]}); + auto* data = input_tensor->mutable_data(); + // read img and pre-process + cv::Mat img = imread(img_path, cv::IMREAD_COLOR); + + pre_process(img, input_shape[3], input_shape[2], *input_tensor); + + // 4. Run predictor + for (int i = 0; i < warmup; ++i) { + predictor->Run(); + } + double lps = 0.f; + double min_time = 1000000.f; + double max_time = 0.f; + for (int i = 0; i < test_iter; ++i) { + clock_t begin = clock(); + predictor->Run(); + clock_t end = clock(); + double t = (end - begin) * 1000; + t = t / CLOCKS_PER_SEC; + lps += t; + if (t < min_time) { + min_time = t; + } + if (t > max_time) { + max_time = t; + } + std::cout << "iter: " << i << ", time: " << t << " ms" << std::endl; + } + std::cout << "================== Speed Report ===================" + << std::endl; + std::cout << "Model: " << model_dir + << ", power_mode: " << static_cast(power_mode) + << ", threads num " << thread_num << ", warmup: " << warmup + << ", repeats: " << test_iter << ", avg time: " << lps / test_iter + << " ms" + << ", min time: " << min_time << " ms" + << ", max time: " << max_time << " ms." << std::endl; + + // 5. Get output and post process + std::unique_ptr output_tensor( + std::move(predictor->GetOutput(0))); + auto* outptr = output_tensor->data(); + auto shape_out = output_tensor->shape(); + int output_num = 1; + for (int i = 0; i < shape_out.size(); ++i) { + output_num *= shape_out[i]; + } + std::cout << "output_num: " << output_num << std::endl; + for (int i = 0; i < output_num; i += 100) { + std::cout << "i: " << i << ", out: " << outptr[i] << std::endl; + } +} + +int main(int argc, char** argv) { + if (argc < 7) { + std::cerr << "[ERROR] usage: " << argv[0] + << " model_dir image_path input_shape\n"; + exit(1); + } + std::string model_dir = argv[1]; + std::string img_path = argv[2]; + std::vector input_shape; + input_shape.push_back(atoi(argv[3])); + input_shape.push_back(atoi(argv[4])); + input_shape.push_back(atoi(argv[5])); + input_shape.push_back(atoi(argv[6])); + int power_mode = 3; + int threads = 1; + int test_iter = 100; + int warmup = 10; + if (argc > 7) { + power_mode = atoi(argv[7]); + } + if (argc > 8) { + threads = atoi(argv[8]); + } + if (argc > 9) { + test_iter = atoi(argv[9]); + } + if (argc > 10) { + warmup = atoi(argv[10]); + } + RunModel(model_dir, + img_path, + input_shape, + (PowerMode)power_mode, + threads, + test_iter, + warmup); + return 0; +} diff --git a/lite/demo/cxx/yolov3_detection/yolov3_detection.cc b/lite/demo/cxx/yolov3_detection/yolov3_detection.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9beb1ed28de1f3c28eb5c03b3b660d518ee10c5 --- /dev/null +++ b/lite/demo/cxx/yolov3_detection/yolov3_detection.cc @@ -0,0 +1,238 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" // NOLINT + +using namespace paddle::lite_api; // NOLINT + +struct Object { + cv::Rect rec; + int class_id; + float prob; +}; + +int64_t ShapeProduction(const shape_t& shape) { + int64_t res = 1; + for (auto i : shape) res *= i; + return res; +} + +const char* class_names[] = {"person", "bicycle", "car", + "motorcycle", "airplane", "bus", + "train", "truck", "boat", + "traffic light", "fire hydrant", "stop sign", + "parking meter", "bench", "bird", + "cat", "dog", "horse", + "sheep", "cow", "elephant", + "bear", "zebra", "giraffe", + "backpack", "umbrella", "handbag", + "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", + "kite", "baseball bat", "baseball glove", + "skateboard", "surfboard", "tennis racket", + "bottle", "wine glass", "cup", + "fork", "knife", "spoon", + "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", + "carrot", "hot dog", "pizza", + "donut", "cake", "chair", + "couch", "potted plant", "bed", + "dining table", "toilet", "tv", + "laptop", "mouse", "remote", + "keyboard", "cell phone", "microwave", + "oven", "toaster", "sink", + "refrigerator", "book", "clock", + "vase", "scissors", "teddy bear", + "hair drier", "toothbrush"}; + +// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up +void neon_mean_scale(const float* din, + float* dout, + int size, + const std::vector mean, + const std::vector scale) { + if (mean.size() != 3 || scale.size() != 3) { + std::cerr << "[ERROR] mean or scale size must equal to 3\n"; + exit(1); + } + float32x4_t vmean0 = vdupq_n_f32(mean[0]); + float32x4_t vmean1 = vdupq_n_f32(mean[1]); + float32x4_t vmean2 = vdupq_n_f32(mean[2]); + float32x4_t vscale0 = vdupq_n_f32(1.f / scale[0]); + float32x4_t vscale1 = vdupq_n_f32(1.f / scale[1]); + float32x4_t vscale2 = vdupq_n_f32(1.f / scale[2]); + + float* dout_c0 = dout; + float* dout_c1 = dout + size; + float* dout_c2 = dout + size * 2; + + int i = 0; + for (; i < size - 3; i += 4) { + float32x4x3_t vin3 = vld3q_f32(din); + float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0); + float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1); + float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2); + float32x4_t vs0 = vmulq_f32(vsub0, vscale0); + float32x4_t vs1 = vmulq_f32(vsub1, vscale1); + float32x4_t vs2 = vmulq_f32(vsub2, vscale2); + vst1q_f32(dout_c0, vs0); + vst1q_f32(dout_c1, vs1); + vst1q_f32(dout_c2, vs2); + + din += 12; + dout_c0 += 4; + dout_c1 += 4; + dout_c2 += 4; + } + for (; i < size; i++) { + *(dout_c0++) = (*(din++) - mean[0]) * scale[0]; + *(dout_c1++) = (*(din++) - mean[1]) * scale[1]; + *(dout_c2++) = (*(din++) - mean[2]) * scale[2]; + } +} + +void pre_process(const cv::Mat& img, int width, int height, float* data) { + cv::Mat rgb_img; + cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB); + cv::resize( + rgb_img, rgb_img, cv::Size(width, height), 0.f, 0.f, cv::INTER_CUBIC); + cv::Mat imgf; + rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f); + std::vector mean = {0.485f, 0.456f, 0.406f};
std::vector scale = {0.229f, 0.224f, 0.225f}; + const float* dimg = reinterpret_cast(imgf.data); + neon_mean_scale(dimg, data, width * height, mean, scale); +} + +std::vector detect_object(const float* data, + int count, + float thresh, + cv::Mat& image) { // NOLINT + if (data == nullptr) { + std::cerr << "[ERROR] data can not be nullptr\n"; + exit(1); + } + std::vector rect_out; + for (int iw = 0; iw < count; iw++) { + int oriw = image.cols; + int orih = image.rows; + if (data[1] > thresh) { + Object obj; + int x = static_cast(data[2]); + int y = static_cast(data[3]); + int w = static_cast(data[4] - data[2] + 1); + int h = static_cast(data[5] - data[3] + 1); + cv::Rect rec_clip = + cv::Rect(x, y, w, h) & cv::Rect(0, 0, image.cols, image.rows); + obj.class_id = static_cast(data[0]); + obj.prob = data[1]; + obj.rec = rec_clip; + if (w > 0 && h > 0 && obj.prob <= 1) { + rect_out.push_back(obj); + cv::rectangle(image, rec_clip, cv::Scalar(0, 0, 255), 1, cv::LINE_AA); + std::string str_prob = std::to_string(obj.prob); + std::string text = std::string(class_names[obj.class_id]) + ": " + + str_prob.substr(0, str_prob.find(".") + 4); + int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL; + double font_scale = 1.f; + int thickness = 1; + cv::Size text_size = + cv::getTextSize(text, font_face, font_scale, thickness, nullptr); + float new_font_scale = w * 0.5 * font_scale / text_size.width; + text_size = cv::getTextSize( + text, font_face, new_font_scale, thickness, nullptr); + cv::Point origin; + origin.x = x + 3; + origin.y = y + text_size.height + 3; + cv::putText(image, + text, + origin, + font_face, + new_font_scale, + cv::Scalar(0, 255, 255), + thickness, + cv::LINE_AA); + + std::cout << "detection, image size: " << image.cols << ", " + << image.rows + << ", detect object: " << class_names[obj.class_id] + << ", score: " << obj.prob << ", location: x=" << x + << ", y=" << y << ", width=" << w << ", height=" << h + << std::endl; + } + } + data += 6; + } + return rect_out; +} + +void RunModel(std::string model_dir, std::string img_path) { + // 1. Set MobileConfig + MobileConfig config; + config.set_model_dir(model_dir); + + // 2. Create PaddlePredictor by MobileConfig + std::shared_ptr predictor = + CreatePaddlePredictor(config); + + const int in_width = 608; + const int in_height = 608; + + // 3. Prepare input data from image + // input 0 + std::unique_ptr input_tensor0(std::move(predictor->GetInput(0))); + input_tensor0->Resize({1, 3, in_height, in_width}); + auto* data0 = input_tensor0->mutable_data(); + cv::Mat img = imread(img_path, cv::IMREAD_COLOR); + pre_process(img, in_width, in_height, data0); + // input1 + std::unique_ptr input_tensor1(std::move(predictor->GetInput(1))); + input_tensor1->Resize({1, 2}); + auto* data1 = input_tensor1->mutable_data(); + data1[0] = img.rows; + data1[1] = img.cols; + + // 4. Run predictor + predictor->Run(); + + // 5. 
Get output and post process + std::unique_ptr output_tensor( + std::move(predictor->GetOutput(0))); + auto* outptr = output_tensor->data(); + auto shape_out = output_tensor->shape(); + int64_t cnt = 1; + for (auto& i : shape_out) { + cnt *= i; + } + auto rec_out = detect_object(outptr, static_cast(cnt / 6), 0.5f, img); + std::string result_name = + img_path.substr(0, img_path.find(".")) + "_yolov3_detection_result.jpg"; + cv::imwrite(result_name, img); +} + +int main(int argc, char** argv) { + if (argc < 3) { + std::cerr << "[ERROR] usage: " << argv[0] << " model_dir image_path\n"; + exit(1); + } + std::string model_dir = argv[1]; + std::string img_path = argv[2]; + RunModel(model_dir, img_path); + return 0; +} diff --git a/lite/fluid/eigen.h b/lite/fluid/eigen.h index eac5332b53c857b05aacbfa95ee2e4b9fcd98a93..c3af7e9f6c3588f404c614430bf01f7ab5e099e5 100644 --- a/lite/fluid/eigen.h +++ b/lite/fluid/eigen.h @@ -30,13 +30,20 @@ struct EigenDim { using Type = Eigen::DSizes; static Type From(const lite::DDim& dims) { - PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size"); + PADDLE_ENFORCE_EQ(dims.size(), D, "D must match DDim::size"); Type ret; for (size_t d = 0; d < dims.size(); d++) { ret[d] = dims[d]; } return ret; } + + static Type From(const DDim::value_type length) { + PADDLE_ENFORCE_EQ(D, 1, "D must be 1."); + Type ret; + ret[0] = length; + return ret; + } }; // Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. @@ -52,7 +59,7 @@ struct EigenTensor { using ConstType = Eigen::TensorMap>; - static Type From(Tensor& tensor, lite::DDim dims) { // NOLINT + static Type From(Tensor& tensor, const lite::DDim& dims) { // NOLINT return Type(const_cast(tensor.data()), EigenDim::From(dims)); // NOLINT } @@ -61,7 +68,7 @@ struct EigenTensor { return From(tensor, tensor.dims()); } // NOLINT - static ConstType From(const Tensor& tensor, lite::DDim dims) { + static ConstType From(const Tensor& tensor, const lite::DDim& dims) { return ConstType(tensor.data(), EigenDim::From(dims)); } @@ -97,14 +104,15 @@ template { // Flatten reshapes a Tensor into an EigenVector. static typename EigenVector::Type Flatten(Tensor& tensor) { // NOLINT - return EigenVector::From( - tensor, lite::DDim(std::vector({tensor.dims().production()}))); + return typename EigenVector::Type( + const_cast(tensor.data()), + EigenDim<1>::From(tensor.dims().production())); } static typename EigenVector::ConstType Flatten( const Tensor& tensor) { // NOLINT - return EigenVector::From( - tensor, lite::DDim(std::vector({tensor.dims().production()}))); + return typename EigenVector::ConstType( + tensor.data(), EigenDim<1>::From(tensor.dims().production())); } }; diff --git a/lite/kernels/CMakeLists.txt b/lite/kernels/CMakeLists.txt index 0bfd39ae9a0bdf6e8af606711fd4dcc6011994b5..4e0092b392eb31ce81f2a410ea86002b343f0aec 100644 --- a/lite/kernels/CMakeLists.txt +++ b/lite/kernels/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(opencl) add_subdirectory(fpga) add_subdirectory(npu) add_subdirectory(xpu) +add_subdirectory(bm) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 731df6e6629826016cafc386284a17f754f83ece..60d5e3b5e234ef19cd144100d07441eb4acf48de 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -1,3 +1,12 @@ +# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered +# to the model_optimize_tool. 
+if((NOT LITE_ON_MODEL_OPTIMIZE_TOOL) AND (NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))) + return() +endif() + +message(STATUS "compile with lite ARM kernels") + +# 1. basic kernels for basic models # for conv op add_kernel(conv_depthwise ARM basic SRCS conv_depthwise.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(conv_direct ARM basic SRCS conv_direct.cc DEPS ${lite_kernel_deps} math_arm) @@ -14,50 +23,65 @@ add_kernel(scale_compute_arm ARM basic SRCS scale_compute.cc DEPS ${lite_kernel_ add_kernel(softmax_compute_arm ARM basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(decode_bboxes_compute_arm ARM basic SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) + add_kernel(pool_compute_arm ARM basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(concat_compute_arm ARM basic SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(pad2d_compute_arm ARM basic SRCS pad2d_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(prior_box_compute_arm ARM basic SRCS prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(density_prior_box_compute_arm ARM basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(negative_compute_arm ARM basic SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(crop_compute_arm ARM basic SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(calib_compute_arm ARM basic SRCS calib_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(transpose_compute_arm ARM basic SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(power_compute_arm ARM basic SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(yolo_box_compute_arm ARM basic SRCS yolo_box_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(shuffle_channel_compute_arm ARM basic SRCS shuffle_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(axpy_compute_arm ARM basic SRCS axpy_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(conv_transpose_compute_arm ARM basic SRCS conv_transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(norm_compute_arm ARM basic SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(shape_compute_arm ARM basic SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(unsqueeze_compute_arm ARM extra SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps} 
math_arm) +add_kernel(unsqueeze_compute_arm ARM basic SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(expand_compute_arm ARM basic SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(grid_sampler_compute_arm ARM basic SRCS grid_sampler_compute.cc DEPS ${lite_kernel_deps} math_arm) + +## 2. other basic kernels: basic kernels that are not used in basic models +add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm) + +## 3.
extra kernels +add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(decode_bboxes_compute_arm ARM extra SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(density_prior_box_compute_arm ARM basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(axpy_compute_arm ARM extra SRCS axpy_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(shape_compute_arm ARM extra SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_max_compute_arm ARM extra SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sequence_expand_compute_arm ARM extra SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(sequence_pool_compute_arm ARM extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(reduce_mean_compute_arm ARM extra SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(stack_compute_arm ARM extra SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(affine_channel_compute_arm ARM extra SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(split_lod_tensor_compute_arm ARM extra SRCS split_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(merge_lod_tensor_compute_arm ARM extra SRCS merge_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(range_compute_arm ARM extra SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) + # for OCR specific add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -74,35 +98,28 @@ add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) 
-add_kernel(fill_constant_compute_arm ARM extra SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) -# NOTE we leave the add_kernel not protected by LITE_WITH_LIGHT_WEIGHT_FRAMEWORK so that all the kernels will be registered -# to the model_optimize_tool. -if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) - return() -endif() - -message(STATUS "compile with lite ARM kernels") lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm) lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm) -lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm) -lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm) -lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra) lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm) -lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) -lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm) - +lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) if(LITE_BUILD_EXTRA) + lite_cc_test(test_split_lod_tensor_compute_arm SRCS split_lod_tensor_compute_test.cc DEPS split_lod_tensor_compute_arm) + lite_cc_test(test_merge_lod_tensor_compute_arm SRCS merge_lod_tensor_compute_test.cc DEPS merge_lod_tensor_compute_arm) + lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm) + lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm) + lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm) lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm) endif() diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc index bc274ea22485e84a1cc9145e62fc967f2847c5dd..266ae1fc916af4303aca274c39b9b4923fdbb154 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/arm/cast_compute.cc @@ -56,6 +56,12 @@ void CastCompute::Run() { float* out_data = param.Out->mutable_data(); std::transform( x_data_begin, x_data_end, out_data, TransOp); + } else if (param.in_dtype == 3 && param.out_dtype == 2) { + const int64_t* x_data_begin = param.X->data(); + const int64_t* x_data_end = 
x_data_begin + param.X->numel(); + int32_t* out_data = param.Out->mutable_data(); + std::transform( + x_data_begin, x_data_end, out_data, TransOp); } else { LOG(FATAL) << "other has not been implemented"; } @@ -68,6 +74,6 @@ void CastCompute::Run() { REGISTER_LITE_KERNEL( cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .Finalize(); diff --git a/lite/kernels/arm/collect_fpn_proposals_compute.cc b/lite/kernels/arm/collect_fpn_proposals_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..d54b96348e866bbe16898ddd6fdbd45beb62afa0 --- /dev/null +++ b/lite/kernels/arm/collect_fpn_proposals_compute.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/collect_fpn_proposals_compute.h" +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +struct ScoreWithID { + float score; + int batch_id; + int index; + int level; + ScoreWithID() { + batch_id = -1; + index = -1; + level = -1; + } + ScoreWithID(float score_, int batch_id_, int index_, int level_) { + score = score_; + batch_id = batch_id_; + index = index_; + level = level_; + } +}; + +static inline bool CompareByScore(ScoreWithID a, ScoreWithID b) { + return a.score > b.score; +} + +static inline bool CompareByBatchid(ScoreWithID a, ScoreWithID b) { + return a.batch_id < b.batch_id; +} + +void CollectFpnProposalsCompute::Run() { + auto& param = Param(); + auto multi_layer_rois = param.multi_level_rois; + auto multi_layer_scores = param.multi_level_scores; + auto* fpn_rois = param.fpn_rois; + int post_nms_topN = param.post_nms_topN; + + if (multi_layer_rois.size() != multi_layer_scores.size()) { + LOG(FATAL) << "multi_layer_rois.size() should be equal to " + "multi_layer_scores.size()"; + } + + size_t num_fpn_level = multi_layer_rois.size(); + std::vector integral_of_all_rois(num_fpn_level + 1, 0); + for (size_t i = 0; i < num_fpn_level; ++i) { + auto cur_rois_lod = multi_layer_rois[i]->lod().back(); + integral_of_all_rois[i + 1] = static_cast( + integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1]); + } + + std::vector scores_of_all_rois( + integral_of_all_rois[num_fpn_level], ScoreWithID()); + for (int i = 0; i < num_fpn_level; ++i) { + const float* cur_level_scores = multi_layer_scores[i]->data(); + int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i]; + auto cur_scores_lod = multi_layer_scores[i]->lod().back(); + int cur_batch_id = 0; + for (int j = 0; j < cur_level_num; ++j) { + if (j >=
cur_scores_lod[cur_batch_id + 1]) { + cur_batch_id++; + } + int cur_index = j + integral_of_all_rois[i]; + scores_of_all_rois[cur_index].score = cur_level_scores[j]; + scores_of_all_rois[cur_index].index = j; + scores_of_all_rois[cur_index].level = i; + scores_of_all_rois[cur_index].batch_id = cur_batch_id; + } + } + + // keep top post_nms_topN rois, sort the rois by the score + if (post_nms_topN > integral_of_all_rois[num_fpn_level]) { + post_nms_topN = integral_of_all_rois[num_fpn_level]; + } + std::stable_sort( + scores_of_all_rois.begin(), scores_of_all_rois.end(), CompareByScore); + scores_of_all_rois.resize(post_nms_topN); + // sort by batch id + std::stable_sort( + scores_of_all_rois.begin(), scores_of_all_rois.end(), CompareByBatchid); + // create a pointer array + std::vector multi_fpn_rois_data(num_fpn_level); + for (int i = 0; i < num_fpn_level; ++i) { + multi_fpn_rois_data[i] = multi_layer_rois[i]->data(); + } + + // initialize the outputs + const int kBoxDim = 4; + auto fpn_rois_data = fpn_rois->mutable_data(); + std::vector lod0(1, 0); + int cur_batch_id = 0; + for (int i = 0; i < post_nms_topN; ++i) { + int cur_fpn_level = scores_of_all_rois[i].level; + int cur_level_index = scores_of_all_rois[i].index; + std::memcpy(fpn_rois_data, + multi_fpn_rois_data[cur_fpn_level] + cur_level_index * kBoxDim, + kBoxDim * sizeof(float)); + fpn_rois_data += kBoxDim; + if (scores_of_all_rois[i].batch_id != cur_batch_id) { + cur_batch_id = scores_of_all_rois[i].batch_id; + lod0.emplace_back(i); + } + } + lod0.emplace_back(post_nms_topN); + lite::LoD lod; + lod.emplace_back(lod0); + fpn_rois->set_lod(lod); + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(collect_fpn_proposals, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::CollectFpnProposalsCompute, + def) + .BindInput("MultiLevelRois", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("MultiLevelScores", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/xpu/bridges/registry.cc b/lite/kernels/arm/collect_fpn_proposals_compute.h similarity index 62% rename from lite/kernels/xpu/bridges/registry.cc rename to lite/kernels/arm/collect_fpn_proposals_compute.h index 4ab1b69a25a29aeb1c1ceaff25525459ef2e94cd..f1e7448a07aee4f9c2b57a1c6d2223f4262c59b4 100644 --- a/lite/kernels/xpu/bridges/registry.cc +++ b/lite/kernels/arm/collect_fpn_proposals_compute.h @@ -12,30 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. 
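For reference, the selection logic in `CollectFpnProposalsCompute::Run` above reduces to: sort all RoIs across FPN levels by score, keep the top `post_nms_topN`, then re-sort by batch id so per-batch LoD offsets can be rebuilt. A self-contained sketch of that select-then-regroup step, with `Item` standing in for the kernel's `ScoreWithID`:

```cpp
#include <algorithm>
#include <vector>

// Stand-in for ScoreWithID; only the fields the selection needs are kept.
struct Item {
  float score;
  int batch_id;
};

// Keep the top_n highest-scoring items across all levels, then restore
// batch order, mirroring the two stable_sort passes in Run().
std::vector<Item> select_top_n(std::vector<Item> items, int top_n) {
  if (top_n > static_cast<int>(items.size())) {
    top_n = static_cast<int>(items.size());
  }
  // A strict comparison (>) keeps stable_sort well-defined; ties preserve
  // the original, level-major order.
  std::stable_sort(items.begin(), items.end(), [](const Item& a, const Item& b) {
    return a.score > b.score;
  });
  items.resize(top_n);
  std::stable_sort(items.begin(), items.end(), [](const Item& a, const Item& b) {
    return a.batch_id < b.batch_id;
  });
  return items;
}
```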
-#include "lite/kernels/xpu/bridges/registry.h" -#include +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/operators/axpy_op.h" namespace paddle { namespace lite { namespace kernels { -namespace xpu { -namespace bridges { +namespace arm { -Factory& Factory::Instance() { - static Factory g_xpu_bridge; - return g_xpu_bridge; -} +class CollectFpnProposalsCompute + : public KernelLite { + public: + using param_t = operators::CollectFpnProposalsParam; -bool Factory::HasType(const std::string& op_type) const { - return map_.count(op_type); -} + void Run() override; -void Factory::Insert(const std::string& op_type, const func_type& func_name) { - map_.insert(std::make_pair(op_type, func_name)); -} + virtual ~CollectFpnProposalsCompute() = default; +}; -} // namespace bridges -} // namespace xpu +} // namespace arm } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/kernels/arm/compare_compute.cc b/lite/kernels/arm/compare_compute.cc index 95014b4ccd427e152dfe919643afa5ff5eb3011d..6118cbc6e403645cada84d2434497b084636a4a3 100644 --- a/lite/kernels/arm/compare_compute.cc +++ b/lite/kernels/arm/compare_compute.cc @@ -112,6 +112,42 @@ void CompareCompute::Run() { } } +template